From cd8f0692aac00b2c86a3ae2f972582f001be52d6 Mon Sep 17 00:00:00 2001 From: st-- Date: Sun, 10 Apr 2022 11:09:08 +0200 Subject: [PATCH] reactivate AD tests: mean functions (#313) * reactivate mean function AD tests * extend mean function tests to ColVecs/RowVecs * unify testcases * remove rrules and ChainRulesCore Co-authored-by: David Widmann Co-authored-by: willtebbutt --- Project.toml | 4 +--- src/AbstractGPs.jl | 1 - src/mean_function.jl | 7 +------ test/Project.toml | 2 -- test/mean_function.jl | 45 +++++++++++++------------------------------ test/runtests.jl | 1 - test/test_util.jl | 30 ++++++++--------------------- 7 files changed, 23 insertions(+), 67 deletions(-) diff --git a/Project.toml b/Project.toml index 403e508b..7897c5e1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,10 +1,9 @@ name = "AbstractGPs" uuid = "99985d1d-32ba-4be9-9821-2ec096f28918" authors = ["JuliaGaussianProcesses Team"] -version = "0.5.11" +version = "0.5.12" [deps] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" @@ -19,7 +18,6 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -ChainRulesCore = "1" Distributions = "0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25" FillArrays = "0.7, 0.8, 0.9, 0.10, 0.11, 0.12, 0.13" IrrationalConstants = "0.1" diff --git a/src/AbstractGPs.jl b/src/AbstractGPs.jl index 9792cbea..acb3c577 100644 --- a/src/AbstractGPs.jl +++ b/src/AbstractGPs.jl @@ -1,6 +1,5 @@ module AbstractGPs -using ChainRulesCore using Distributions using FillArrays using LinearAlgebra diff --git a/src/mean_function.jl b/src/mean_function.jl index 6eacec36..e691b154 100644 --- a/src/mean_function.jl +++ b/src/mean_function.jl @@ -12,11 +12,6 @@ This is an AbstractGPs-internal workaround for AD issues; ideally we would just """ _map_meanfunction(::ZeroMean{T}, x::AbstractVector) where {T} = Zeros{T}(length(x)) -function ChainRulesCore.rrule(::typeof(_map_meanfunction), m::ZeroMean, x::AbstractVector) - map_ZeroMean_pullback(Δ) = (NoTangent(), NoTangent(), ZeroTangent()) - return _map_meanfunction(m, x), map_ZeroMean_pullback -end - ZeroMean() = ZeroMean{Float64}() """ @@ -40,4 +35,4 @@ struct CustomMean{Tf} <: MeanFunction f::Tf end -_map_meanfunction(f::CustomMean, x::AbstractVector) = map(f.f, x) +_map_meanfunction(m::CustomMean, x::AbstractVector) = map(m.f, x) diff --git a/test/Project.toml b/test/Project.toml index 57222240..5051d678 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,4 @@ [deps] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" @@ -14,7 +13,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ChainRulesCore = "1" Distributions = "0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25" Documenter = "0.24, 0.25, 0.26, 0.27" FillArrays = "0.11, 0.12, 0.13" diff --git a/test/mean_function.jl b/test/mean_function.jl index a699ac29..511ea687 100644 --- a/test/mean_function.jl +++ b/test/mean_function.jl @@ -5,42 +5,23 @@ xD_colvecs = ColVecs(randn(rng, D, N)) xD_rowvecs = RowVecs(randn(rng, N, D)) - @testset "ZeroMean" begin - m = ZeroMean{Float64}() + zero_mean_testcase = (; mean_function=ZeroMean(), calc_expected=_ -> zeros(N)) - for x in [x1, xD_colvecs, xD_rowvecs] - @test AbstractGPs._map_meanfunction(m, x) == zeros(N) - #differentiable_mean_function_tests(m, randn(rng, N), x) - - # Manually verify the ChainRule. Really, this should employ FiniteDifferences, but - # currently ChainRulesTestUtils isn't up to handling this, so this will have to do - # for now. - y, pb = rrule(AbstractGPs._map_meanfunction, m, x) - @test y == AbstractGPs._map_meanfunction(m, x) - Δmap, Δf, Δx = pb(randn(rng, N)) - @test iszero(Δmap) - @test iszero(Δf) - @test iszero(Δx) - end - end - - @testset "ConstMean" begin - c = randn(rng) - m = ConstMean(c) - - for x in [x1, xD_colvecs, xD_rowvecs] - @test AbstractGPs._map_meanfunction(m, x) == fill(c, N) - #differentiable_mean_function_tests(m, randn(rng, N), x) - end - end + c = randn(rng) + const_mean_testcase = (; mean_function=ConstMean(c), calc_expected=_ -> fill(c, N)) - @testset "CustomMean" begin - foo_mean = x -> sum(abs2, x) - m = CustomMean(foo_mean) + foo_mean = x -> sum(abs2, x) + custom_mean_testcase = (; + mean_function=CustomMean(foo_mean), calc_expected=x -> map(foo_mean, x) + ) + @testset "$(typeof(testcase.mean_function))" for testcase in [ + zero_mean_testcase, const_mean_testcase, custom_mean_testcase + ] for x in [x1, xD_colvecs, xD_rowvecs] - @test AbstractGPs._map_meanfunction(m, x) == map(foo_mean, x) - #differentiable_mean_function_tests(m, randn(rng, N), x) + m = testcase.mean_function + @test AbstractGPs._map_meanfunction(m, x) == testcase.calc_expected(x) + differentiable_mean_function_tests(rng, m, x) end end end diff --git a/test/runtests.jl b/test/runtests.jl index 4166a913..415c0032 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,7 +23,6 @@ using AbstractGPs: TestUtils using Documenter -using ChainRulesCore using Distributions: MvNormal, PDMat, loglikelihood, Distributions using FillArrays using FiniteDifferences diff --git a/test/test_util.jl b/test/test_util.jl index e2aa2c83..e525b1fa 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -73,8 +73,9 @@ end Test _very_ basic consistency properties of the mean function `m`. """ function mean_function_tests(m::MeanFunction, x::AbstractVector) - @test AbstractGPs._map_meanfunction(m, x) isa AbstractVector - @test length(ew(m, x)) == length(x) + mean = AbstractGPs._map_meanfunction(m, x) + @test mean isa AbstractVector + @test length(mean) == length(x) end """ @@ -87,34 +88,19 @@ end Ensure that the gradient w.r.t. the inputs of `MeanFunction` `m` are approximately correct. """ function differentiable_mean_function_tests( - m::MeanFunction, - ȳ::AbstractVector{<:Real}, - x::AbstractVector{<:Real}; - rtol=_rtol, - atol=_atol, + m::MeanFunction, ȳ::AbstractVector, x::AbstractVector; rtol=_rtol, atol=_atol ) # Run forward tests. mean_function_tests(m, x) # Check adjoint. @assert length(ȳ) == length(x) - return adjoint_test(x -> ew(m, x), ȳ, x; rtol=rtol, atol=atol) + adjoint_test( + x -> collect(AbstractGPs._map_meanfunction(m, x)), ȳ, x; rtol=rtol, atol=atol + ) + return nothing end -# function differentiable_mean_function_tests( -# m::MeanFunction, -# ȳ::AbstractVector{<:Real}, -# x::ColVecs{<:Real}; -# rtol=_rtol, -# atol=_atol, -# ) -# # Run forward tests. -# mean_function_tests(m, x) - -# @assert length(ȳ) == length(x) -# adjoint_test(X->ew(m, ColVecs(X)), ȳ, x.X; rtol=rtol, atol=atol) -# end - function differentiable_mean_function_tests( rng::AbstractRNG, m::MeanFunction, x::AbstractVector; rtol=_rtol, atol=_atol )