From 3aced77d193f09bb32870f3859dd9d884d9d344e Mon Sep 17 00:00:00 2001
From: David Widmann
Date: Mon, 30 Aug 2021 09:37:26 +0200
Subject: [PATCH] Add ChainRules adjoints (#106)

* Add ChainRules adjoints

* Move differentiation rules to a separate file

* Update to new test syntax

* `rand_tangent` is fixed upstream

* Add support for ChainRulesCore 0.10

* Fix definition of chainrule for `poislogpdf`

* Use ChainRulesCore 1

* Only support ChainRulesCore 1
---
NOTE(review): line breaks and indentation in this patch were reconstructed from a
whitespace-collapsed copy and checked against the hunk line counts. The exact
whitespace of unchanged context lines is assumed; confirm it against the target
tree, or apply with `git apply --ignore-whitespace`.

 Project.toml        |  6 +++-
 src/StatsFuns.jl    |  3 ++
 src/chainrules.jl   | 78 +++++++++++++++++++++++++++++++++++++++++++++
 src/distrs/chisq.jl |  2 +-
 src/distrs/pois.jl  |  2 +-
 src/distrs/tdist.jl |  2 +-
 test/chainrules.jl  | 56 ++++++++++++++++++++++++++++++++
 test/runtests.jl    |  2 +-
 8 files changed, 146 insertions(+), 5 deletions(-)
 create mode 100644 src/chainrules.jl
 create mode 100644 test/chainrules.jl

diff --git a/Project.toml b/Project.toml
index 5cdb389..91b946e 100644
--- a/Project.toml
+++ b/Project.toml
@@ -3,6 +3,7 @@ uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
 version = "0.9.9"
 
 [deps]
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
 LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -10,6 +11,7 @@ Rmath = "79098fc4-a85e-5d69-aa6a-4863f24498fa"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 
 [compat]
+ChainRulesCore = "1"
 IrrationalConstants = "0.1"
 LogExpFunctions = "0.3"
 Reexport = "1"
@@ -18,8 +20,10 @@ SpecialFunctions = "0.8, 0.9, 0.10, 1.0"
 julia = "1"
 
 [extras]
+ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["ForwardDiff", "Test"]
+test = ["ChainRulesTestUtils", "ForwardDiff", "Random", "Test"]
diff --git a/src/StatsFuns.jl b/src/StatsFuns.jl
index f5f460e..2ff3414 100644
--- a/src/StatsFuns.jl
+++ b/src/StatsFuns.jl
@@ -5,6 +5,7 @@ module StatsFuns
 using Base: Math.@horner
 using Reexport
 using SpecialFunctions
+import ChainRulesCore
 
 # reexports
 @reexport using IrrationalConstants:
@@ -257,4 +258,6 @@ include(joinpath("distrs", "pois.jl"))
 include(joinpath("distrs", "tdist.jl"))
 include(joinpath("distrs", "srdist.jl"))
 
+include("chainrules.jl")
+
 end # module
diff --git a/src/chainrules.jl b/src/chainrules.jl
new file mode 100644
index 0000000..f405c21
--- /dev/null
+++ b/src/chainrules.jl
@@ -0,0 +1,78 @@
+ChainRulesCore.@scalar_rule(
+    betalogpdf(α::Real, β::Real, x::Number),
+    @setup(z = digamma(α + β)),
+    (
+        log(x) + z - digamma(α),
+        log1p(-x) + z - digamma(β),
+        (α - 1) / x + (1 - β) / (1 - x),
+    ),
+)
+
+ChainRulesCore.@scalar_rule(
+    binomlogpdf(n::Real, p::Real, k::Real),
+    @setup(z = digamma(n - k + 1)),
+    (
+        digamma(n + 2) - z + log1p(-p) - 1 / (1 + n),
+        (k / p - n) / (1 - p),
+        z - digamma(k + 1) + logit(p),
+    ),
+)
+
+ChainRulesCore.@scalar_rule(
+    chisqlogpdf(k::Real, x::Number),
+    @setup(hk = k / 2),
+    (
+        (log(x) - logtwo - digamma(hk)) / 2,
+        (hk - 1) / x - one(hk) / 2,
+    ),
+)
+
+ChainRulesCore.@scalar_rule(
+    fdistlogpdf(ν1::Real, ν2::Real, x::Number),
+    @setup(
+        xν1 = x * ν1,
+        temp1 = xν1 + ν2,
+        a = (x - 1) / temp1,
+        νsum = ν1 + ν2,
+        di = digamma(νsum / 2),
+    ),
+    (
+        (-log1p(ν2 / xν1) - ν2 * a + di - digamma(ν1 / 2)) / 2,
+        (-log1p(xν1 / ν2) + ν1 * a + di - digamma(ν2 / 2)) / 2,
+        ((ν1 - 2) / x - ν1 * νsum / temp1) / 2,
+    ),
+)
+
+ChainRulesCore.@scalar_rule(
+    gammalogpdf(k::Real, θ::Real, x::Number),
+    @setup(
+        invθ = inv(θ),
+        xoθ = invθ * x,
+        z = xoθ - k,
+    ),
+    (
+        log(xoθ) - digamma(k),
+        invθ * z,
+        - (1 + z) / x,
+    ),
+)
+
+ChainRulesCore.@scalar_rule(
+    poislogpdf(λ::Number, x::Number),
+    ((iszero(x) && iszero(λ) ? zero(x / λ) : x / λ) - 1, log(λ) - digamma(x + 1)),
+)
+
+ChainRulesCore.@scalar_rule(
+    tdistlogpdf(ν::Real, x::Number),
+    @setup(
+        νp1 = ν + 1,
+        xsq = x^2,
+        invν = inv(ν),
+        a = xsq * invν,
+        b = νp1 / (ν + xsq),
+    ),
+    (
+        (digamma(νp1 / 2) - digamma(ν / 2) + a * b - log1p(a) - invν) / 2,
+        - x * b,
+    ),
+)
diff --git a/src/distrs/chisq.jl b/src/distrs/chisq.jl
index 606923e..25a069d 100644
--- a/src/distrs/chisq.jl
+++ b/src/distrs/chisq.jl
@@ -21,5 +21,5 @@ end
 # logpdf for numbers with generic types
 function chisqlogpdf(k::Real, x::Number)
     hk = k / 2 # half k
-    -hk * log(oftype(hk, 2)) - loggamma(hk) + (hk - 1) * log(x) - x / 2
+    -hk * logtwo - loggamma(hk) + (hk - 1) * log(x) - x / 2
 end
diff --git a/src/distrs/pois.jl b/src/distrs/pois.jl
index 1c1ad3b..3caffc2 100644
--- a/src/distrs/pois.jl
+++ b/src/distrs/pois.jl
@@ -27,4 +27,4 @@ function poislogpdf(λ::Union{Float32,Float64}, x::Union{Float64,Float32,Integer
         -λ
     else
         -lstirling_asym(x + 1)
-=#
\ No newline at end of file
+=#
diff --git a/src/distrs/tdist.jl b/src/distrs/tdist.jl
index 322a103..5dc83e9 100644
--- a/src/distrs/tdist.jl
+++ b/src/distrs/tdist.jl
@@ -16,4 +16,4 @@ import .RFunctions:
 tdistpdf(ν::Real, x::Number) = gamma((ν + 1) / 2) / (sqrt(ν * pi) * gamma(ν / 2)) * (1 + x^2 / ν)^(-(ν + 1) / 2)
 
 # logpdf for numbers with generic types
-tdistlogpdf(ν::Real, x::Number) = loggamma((ν + 1) / 2) - log(ν * pi) / 2 - loggamma(ν / 2) + (-(ν + 1) / 2) * log(1 + x^2 / ν)
+tdistlogpdf(ν::Real, x::Number) = loggamma((ν + 1) / 2) - log(ν * pi) / 2 - loggamma(ν / 2) + (-(ν + 1) / 2) * log1p(x^2 / ν)
diff --git a/test/chainrules.jl b/test/chainrules.jl
new file mode 100644
index 0000000..e469ebf
--- /dev/null
+++ b/test/chainrules.jl
@@ -0,0 +1,56 @@
+using StatsFuns, Test
+using ChainRulesCore
+using ChainRulesTestUtils
+using Random
+
+@testset "chainrules" begin
+    x = exp(randn())
+    y = exp(randn())
+    z = logistic(randn())
+    test_frule(betalogpdf, x, y, z)
+    test_rrule(betalogpdf, x, y, z)
+
+    x = exp(randn())
+    y = exp(randn())
+    z = exp(randn())
+    test_frule(gammalogpdf, x, y, z)
+    test_rrule(gammalogpdf, x, y, z)
+
+    x = exp(randn())
+    y = exp(randn())
+    test_frule(chisqlogpdf, x, y)
+    test_rrule(chisqlogpdf, x, y)
+
+    x = exp(randn())
+    y = exp(randn())
+    z = exp(randn())
+    test_frule(fdistlogpdf, x, y, z)
+    test_rrule(fdistlogpdf, x, y, z)
+
+    x = exp(randn())
+    y = randn()
+    test_frule(tdistlogpdf, x, y)
+    test_rrule(tdistlogpdf, x, y)
+
+    # use `BigFloat` to avoid Rmath implementation in finite differencing check
+    # (returns `NaN` for non-integer values)
+    n = rand(1:100)
+    x = BigFloat(n)
+    y = big(logistic(randn()))
+    z = BigFloat(rand(1:n))
+    test_frule(binomlogpdf, x, y, z)
+    test_rrule(binomlogpdf, x, y, z)
+
+    x = big(exp(randn()))
+    y = BigFloat(rand(1:100))
+    test_frule(poislogpdf, x, y)
+    test_rrule(poislogpdf, x, y)
+
+    # test special case λ = 0
+    _, pb = rrule(StatsFuns.poislogpdf, 0.0, 0.0)
+    _, x̄1, _ = pb(1)
+    @test x̄1 == -1
+    _, pb = rrule(StatsFuns.poislogpdf, 0.0, 1.0)
+    _, x̄1, _ = pb(1)
+    @test x̄1 == Inf
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 88c7c4d..4ddedc2 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,4 +1,4 @@
-tests = ["rmath", "generic", "misc"]
+tests = ["rmath", "generic", "misc", "chainrules"]
 
 for t in tests
     fp = "$t.jl"