Skip to content

Commit

Permalink
Better docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Jun 10, 2024
1 parent 32b002b commit c42cb6b
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/abstract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ Abstract supertype for differentiable parametric expectations `F : θ -> 𝔼[f(
# Required fields
- `dist_constructor`: The constructor of the probability distribution.
- `f`: The function applied inside the expectation.
- `dist_constructor`: The constructor of the probability distribution.
- `rng::AbstractRNG`: The random number generator.
- `nb_samples::Integer`: The number of Monte-Carlo samples.
Expand Down
10 changes: 7 additions & 3 deletions src/reinforce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ true
Reinforce(
f,
dist_constructor,
dist_gradlogpdf=nothing;
dist_logdensity_grad=nothing;
rng=Random.default_rng(),
nb_samples=1,
threaded=false
Expand All @@ -43,10 +43,15 @@ $(TYPEDFIELDS)
- [`DifferentiableExpectation`](@ref)
"""
struct Reinforce{threaded,F,D,G,R<:AbstractRNG} <: DifferentiableExpectation{threaded}
"function applied inside the expectation"
f::F
"constructor of the probability distribution `(θ...) -> p(θ)`"
dist_constructor::D
"either `nothing` or a parameter gradient callable `(x, θ...) -> ∇₂logp(x, θ)`"
dist_logdensity_grad::G
"random number generator"
rng::R
"number of Monte-Carlo samples"
nb_samples::Int
end

Expand Down Expand Up @@ -76,9 +81,8 @@ function dist_logdensity_grad(
) where {threaded}
(; dist_constructor, dist_logdensity_grad) = F
if !isnothing(dist_logdensity_grad)
= dist_logdensity_grad...)
= dist_logdensity_grad(x, θ...)
else
# TODO: add Distributions.gradlogpdf
_logdensity_partial(_θ...) = logdensityof(dist_constructor(_θ...), x)
l, pullback = rrule_via_ad(rc, _logdensity_partial, θ...)
= Base.tail(pullback(one(l)))
Expand Down
23 changes: 23 additions & 0 deletions src/reparametrization.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,27 @@
"""
TransformedDistribution
Represent the probability distribution `p` of a random variable `X ∼ p` with a transformation `X = T(Z)` where `Z ∼ q`.
# Fields
$(TYPEDFIELDS)
"""
struct TransformedDistribution{D,T}
"the distribution `q` that gets transformed into `p`"
base_dist::D
"the transformation function `T`"
transformation::T
end

"""
reparametrize(dist)
Turn a probability distribution `p` into a [`TransformedDistribution`](@ref) `(q, T)` such that the new distribution `q` does not depend on the parameters of `p`.
These parameters are encoded (closed over) in the transformation function `T`.
"""
function reparametrize end

function reparametrize(dist::Normal{T}) where {T}
base_dist = Normal(zero(T), one(T))
μ, σ = mean(dist), std(dist)
Expand Down Expand Up @@ -54,9 +73,13 @@ $(TYPEDFIELDS)
- [`DifferentiableExpectation`](@ref)
"""
struct Reparametrization{threaded,F,D,R<:AbstractRNG} <: DifferentiableExpectation{threaded}
"function applied inside the expectation"
f::F
"constructor of the probability distribution `(θ...) -> p(θ)`"
dist_constructor::D
"random number generator"
rng::R
"number of Monte-Carlo samples"
nb_samples::Int
end

Expand Down

0 comments on commit c42cb6b

Please sign in to comment.