Skip to content

Commit

Permalink
Better docs
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Jun 10, 2024
1 parent 2fb8e1f commit 32b002b
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 17 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ F(\theta) = \mathbb{E}_{p(\theta)}[f(X)]

The following estimators are implemented:

- REINFORCE
- Reparametrization
- [REINFORCE](https://jmlr.org/papers/volume21/19-346/19-346.pdf#section.20)
- [Reparametrization](https://jmlr.org/papers/volume21/19-346/19-346.pdf#section.56)
4 changes: 4 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
[deps]
DifferentiableExpectations = "fc55d66b-b2a8-4ccc-9d64-c0c2166ceb36"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
LiveServer = "16fef848-5104-11e9-1b77-fb7a48bbb589"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
10 changes: 7 additions & 3 deletions src/abstract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@ Abstract supertype for differentiable parametric expectations `F : θ -> 𝔼[f(
# Required fields
- `dist_constructor`: The constructor of the probability distribution, such that calling `dist_constructor(θ...)` generates an object corresponding to `p(θ)`. This object must satisfy:
- the [Random API](https://docs.julialang.org/en/v1/stdlib/Random/#Hooking-into-the-Random-API)
- the [DensityInterface.jl API](https://github.com/JuliaMath/DensityInterface.jl)
- `dist_constructor`: The constructor of the probability distribution.
- `f`: The function applied inside the expectation.
- `rng::AbstractRNG`: The random number generator.
- `nb_samples::Integer`: The number of Monte-Carlo samples.
The field `dist_constructor` must be a callable such that `dist_constructor(θ...)` generates an object `dist` that corresponds to `p(θ)`.
The resulting object `dist` needs to satisfy:
- the [Random API](https://docs.julialang.org/en/v1/stdlib/Random/#Hooking-into-the-Random-API) for sampling with `rand(rng, dist)`
- the [DensityInterface.jl API](https://github.com/JuliaMath/DensityInterface.jl) for loglikelihoods with `logdensityof(dist, x)`
"""
abstract type DifferentiableExpectation{threaded} end

Expand Down
54 changes: 53 additions & 1 deletion src/distribution.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,35 @@
"""
FixedAtomsProbabilityDistribution
FixedAtomsProbabilityDistribution{threaded}
A probability distribution with finite support and fixed atoms.
Whenever its expectation is differentiated, only the weights are considered active, whereas the atoms are considered constant.
# Example
```jldoctest
julia> using DifferentiableExpectations, Statistics, Zygote
julia> dist = FixedAtomsProbabilityDistribution([2, 3], [0.4, 0.6]);
julia> map(abs2, dist)
FixedAtomsProbabilityDistribution{false}([4, 9], [0.4, 0.6])
julia> mean(abs2, dist)
7.0
julia> gradient(mean, abs2, dist)[2]
(atoms = nothing, weights = [4.0, 9.0])
```
# Constructor
FixedAtomsProbabilityDistribution(
atoms::Vector,
weights::Vector;
threaded=false
)
# Fields
$(TYPEDFIELDS)
Expand All @@ -27,13 +52,30 @@ struct FixedAtomsProbabilityDistribution{threaded,A,W<:Real}
end
end

function Base.show(
io::IO, dist::FixedAtomsProbabilityDistribution{threaded}
) where {threaded}
(; atoms, weights) = dist
return print(io, "FixedAtomsProbabilityDistribution{$threaded}($atoms, $weights)")
end

Base.length(dist::FixedAtomsProbabilityDistribution) = length(dist.atoms)

"""
rand(rng, dist::FixedAtomsProbabilityDistribution)
Sample from the atoms of `dist` with probability proportional to their weights.
"""
function Random.rand(rng::AbstractRNG, dist::FixedAtomsProbabilityDistribution)
(; atoms, weights) = dist
return StatsBase.sample(rng, atoms, StatsBase.Weights(weights))
end

"""
map(f, dist::FixedAtomsProbabilityDistribution)
Apply `f` to the atoms of `dist`, leave the weights unchanged.
"""
function Base.map(f, dist::FixedAtomsProbabilityDistribution{threaded}) where {threaded}
(; atoms, weights) = dist
new_atoms = if threaded
Expand All @@ -44,6 +86,11 @@ function Base.map(f, dist::FixedAtomsProbabilityDistribution{threaded}) where {t
return FixedAtomsProbabilityDistribution(new_atoms, weights)
end

"""
mean(dist::FixedAtomsProbabilityDistribution)
Compute the expectation of `dist`, i.e. the sum of all atoms multiplied by their respective weights.
"""
function Statistics.mean(dist::FixedAtomsProbabilityDistribution{threaded}) where {threaded}
(; atoms, weights) = dist
if threaded
Expand All @@ -53,6 +100,11 @@ function Statistics.mean(dist::FixedAtomsProbabilityDistribution{threaded}) wher
end
end

"""
mean(f, dist::FixedAtomsProbabilityDistribution)
Shortcut for `mean(map(f, dist))`.
"""
function Statistics.mean(f, dist::FixedAtomsProbabilityDistribution)
return mean(map(f, dist))
end
Expand Down
17 changes: 17 additions & 0 deletions src/reinforce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,23 @@ Differentiable parametric expectation `F : θ -> 𝔼[f(X)]` where `X ∼ p(θ)`
∂F(θ) = 𝔼[f(X) ∇₂logp(X,θ)ᵀ]
```
# Example
```jldoctest
using DifferentiableExpectations, Distributions, Zygote
F = Reinforce(exp, Normal; nb_samples=10^5)
F_true(μ, σ) = mean(LogNormal(μ, σ))
μ, σ = 0.5, 1,0
∇F, ∇F_true = gradient(F, μ, σ), gradient(F_true, μ, σ)
isapprox(collect(∇F), collect(∇F_true); rtol=1e-1)
# output
true
```
# Constructor
Reinforce(
Expand Down
17 changes: 17 additions & 0 deletions src/reparametrization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,23 @@ Differentiable parametric expectation `F : θ -> 𝔼[f(X)]` where `X ∼ p(θ)`
∂F(θ) = 𝔼_q[∂f(g(Z,θ)) ∂₂g(Z,θ)ᵀ]
```
# Example
```jldoctest
using DifferentiableExpectations, Distributions, Zygote
F = Reparametrization(exp, Normal; nb_samples=10^3)
F_true(μ, σ) = mean(LogNormal(μ, σ))
μ, σ = 0.5, 1,0
∇F, ∇F_true = gradient(F, μ, σ), gradient(F_true, μ, σ)
isapprox(collect(∇F), collect(∇F_true); rtol=1e-1)
# output
true
```
# Constructor
Reparametrization(
Expand Down
16 changes: 9 additions & 7 deletions test/distribution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,20 @@ using Zygote
rng = StableRNG(63)

@test_throws ArgumentError FixedAtomsProbabilityDistribution(Int[], Float64[])
@test_throws DimensionMismatch FixedAtomsProbabilityDistribution([1, 2], [1.0])
@test_throws ArgumentError FixedAtomsProbabilityDistribution([1, 2], [0.5, 0.8])
@test_throws DimensionMismatch FixedAtomsProbabilityDistribution([2, 3], [1.0])
@test_throws ArgumentError FixedAtomsProbabilityDistribution([2, 3], [0.4, 0.8])

for threaded in (false, true)
dist = FixedAtomsProbabilityDistribution([2.0, 3.0], [0.3, 0.7]; threaded)
dist = FixedAtomsProbabilityDistribution([2, 3], [0.4, 0.6]; threaded)

string(dist)

@test length(dist) == 2

@test mean(dist) 2.7
@test mean(abs2, dist) 7.5
@test mean([rand(rng, dist) for _ in 1:(10^5)]) 2.7 rtol = 0.1
@test mean(abs2, [rand(rng, dist) for _ in 1:(10^5)]) 7.5 rtol = 0.1
@test mean(dist) 2.6
@test mean(abs2, dist) 7.0
@test mean([rand(rng, dist) for _ in 1:(10^5)]) 2.6 rtol = 0.1
@test mean(abs2, [rand(rng, dist) for _ in 1:(10^5)]) 7.0 rtol = 0.1

@test map(abs2, dist).weights == dist.weights
@test map(abs2, dist).atoms == [4, 9]
Expand Down
10 changes: 6 additions & 4 deletions test/expectation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ using Zygote
exp_with_kwargs(x; correct=false) = correct ? exp(x) : sin(x)
vec_exp_with_kwargs(x; correct=false) = exp_with_kwargs.(x; correct)

μ, σ = 0.5, 1.0
true_mean(μ, σ) = mean(LogNormal(μ, σ))
true_std(μ, σ) = std(LogNormal(μ, σ))
∇mean_true = gradient(true_mean, μ, σ)

@testset verbose = true "Univariate LogNormal" begin
@testset verbose = true "Threaded: $threaded" for threaded in (false, true)
@testset "$(nameof(typeof(F)))" for F in [
Expand All @@ -29,16 +34,13 @@ vec_exp_with_kwargs(x; correct=false) = exp_with_kwargs.(x; correct)
threaded=threaded,
),
]
μ, σ = 2.0, 1.0
true_mean(μ, σ) = mean(LogNormal(μ, σ))
true_std(μ, σ) = std(LogNormal(μ, σ))
string(F)

@test F.dist_constructor(μ, σ) == Normal(μ, σ)
@test F(μ, σ; correct=true) true_mean(μ, σ) rtol = 0.1
@test std(samples(F, μ, σ; correct=true)) true_std(μ, σ) rtol = 0.1

∇mean_est = gradient((μ, σ) -> F(μ, σ; correct=true), μ, σ)
∇mean_true = gradient(true_mean, μ, σ)

@test ∇mean_est[1] ∇mean_true[1] rtol = 0.2
@test ∇mean_est[2] ∇mean_true[2] rtol = 0.2
Expand Down

0 comments on commit 32b002b

Please sign in to comment.