Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove selector stuff from VarInfo tests and link/invlink #780

Merged
merged 33 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
4dc2a72
Remove selector stuff from varinfo tests
mhauru Jan 16, 2025
9b492a3
Implement link and invlink for varnames rather than samplers
mhauru Jan 16, 2025
b508f08
Replace set_retained_vns_del_by_spl! with set_retained_vns_del!
mhauru Jan 16, 2025
b8880d1
Make linking tests more extensive
mhauru Jan 16, 2025
99a8490
Remove sampler indexing from link methods (but not invlink)
mhauru Jan 22, 2025
4a79b1f
Remove indexing by samplers from invlink
mhauru Jan 22, 2025
26a1901
Merge remote-tracking branch 'origin/master' into mhauru/remove-selec…
mhauru Jan 22, 2025
090608b
Work towards removing sampler indexing with StaticTransformation
mhauru Jan 22, 2025
4749853
Fix invlink/link for TypedVarInfo and StaticTransformation
mhauru Jan 23, 2025
e960679
Fix a test in models.jl
mhauru Jan 23, 2025
d507a53
Move some functions to utils.jl, add tests and docstrings
mhauru Jan 23, 2025
41150b5
Fix a docstring typo
mhauru Jan 23, 2025
836fb13
Merge branch 'release-0.35' into mhauru/remove-selectors-linking
mhauru Jan 23, 2025
45d1f13
Various simplification to link/invlink
mhauru Jan 23, 2025
98915c2
Improve a docstring
mhauru Jan 23, 2025
f05068d
Style improvements
mhauru Jan 23, 2025
bc4c420
Fix broken link/invlink dispatch cascade for VectorVarInfo
mhauru Jan 23, 2025
71980ba
Fix some more broken dispatch cascades
mhauru Jan 23, 2025
45562a9
Apply suggestions from code review
mhauru Jan 24, 2025
db5b835
Remove comments that messed with docstrings
mhauru Jan 24, 2025
f99effe
Apply suggestions from code review
mhauru Jan 28, 2025
56194cd
Fix issues surfaced in code review
mhauru Jan 28, 2025
c187c49
Simplify link/invlink arguments
mhauru Jan 28, 2025
86b25c5
Fix a bug in unflatten VarNamedVector
mhauru Jan 28, 2025
2a6c1bc
Rename VarNameCollection -> VarNameTuple
mhauru Jan 28, 2025
853f47e
Remove test of a removed varname_namedtuple method
mhauru Jan 28, 2025
ed80328
Apply suggestions from code review
mhauru Jan 29, 2025
d996d0c
Respond to review feedback
mhauru Jan 29, 2025
2083148
Remove _default_sampler and a dead argument of maybe_invlink_before_eval
mhauru Jan 29, 2025
39fa647
Fix a typo in a comment
mhauru Jan 29, 2025
9df364f
Merge remote-tracking branch 'origin/release-0.35' into mhauru/remove…
mhauru Jan 30, 2025
2c73de5
Add HISTORY entry, fix one set_retained_vns_del! method
mhauru Jan 30, 2025
49604e1
Merge remote-tracking branch 'origin/release-0.35' into mhauru/remove…
mhauru Jan 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ set_num_produce!
increment_num_produce!
reset_num_produce!
setorder!
set_retained_vns_del_by_spl!
set_retained_vns_del!
```

```@docs
Expand Down
2 changes: 1 addition & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ export AbstractVarInfo,
set_num_produce!,
reset_num_produce!,
increment_num_produce!,
set_retained_vns_del_by_spl!,
set_retained_vns_del!,
is_flagged,
set_flag!,
unset_flag!,
Expand Down
130 changes: 60 additions & 70 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -537,117 +537,118 @@ If `vn` is not specified, then `istrans(vi)` evaluates to `true` for all variabl
"""
function settrans!! end

# For link!!, invlink!!, link, and invlink, we deliberately do not provide a fallback
# method for the case when no `vns` is provided, that would get all the keys from the
# `VarInfo`. Hence each subtype of `AbstractVarInfo` needs to implement separately the case
# where `vns` is provided and the one where it is not. This is because having separate
# implementations is typically much more performant, and because not all AbstractVarInfo
# types support partial linking.

"""
link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model)

Transform variables in `vi` to their linked space, mutating `vi` if possible.

Transform the variables in `vi` to their linked space, using the transformation `t`,
mutating `vi` if possible.
Either transform all variables, or only ones specified in `vns`.

If `t` is not provided, `default_transformation(model, vi)` will be used.
Use the transformation `t`, or `default_transformation(model, vi)` if one is not provided.

See also: [`default_transformation`](@ref), [`invlink!!`](@ref).
"""
link!!(vi::AbstractVarInfo, model::Model) = link!!(vi, SampleFromPrior(), model)
function link!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model)
return link!!(t, vi, SampleFromPrior(), model)
function link!!(vi::AbstractVarInfo, model::Model)
return link!!(default_transformation(model, vi), vi, model)
end
function link!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
# Use `default_transformation` to decide which transformation to use if none is specified.
return link!!(default_transformation(model, vi), vi, spl, model)
function link!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
return link!!(default_transformation(model, vi), vi, vns, model)
end

"""
link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
link([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
link([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model)

Transform variables in `vi` to their linked space without mutating `vi`.

Transform the variables in `vi` to their linked space without mutating `vi`, using the transformation `t`.
Either transform all variables, or only ones specified in `vns`.

If `t` is not provided, `default_transformation(model, vi)` will be used.
Use the transformation `t`, or `default_transformation(model, vi)` if one is not provided.

See also: [`default_transformation`](@ref), [`invlink`](@ref).
"""
link(vi::AbstractVarInfo, model::Model) = link(vi, SampleFromPrior(), model)
function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model)
return link(t, deepcopy(vi), SampleFromPrior(), model)
function link(vi::AbstractVarInfo, model::Model)
return link(default_transformation(model, vi), vi, model)
end
function link(vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
# Use `default_transformation` to decide which transformation to use if none is specified.
return link(default_transformation(model, vi), deepcopy(vi), spl, model)
function link(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
return link(default_transformation(model, vi), vi, vns, model)
end

"""
invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model)

Transform the variables in `vi` to their constrained space, using the (inverse of)
transformation `t`, mutating `vi` if possible.
Transform variables in `vi` to their constrained space, mutating `vi` if possible.

If `t` is not provided, `default_transformation(model, vi)` will be used.
Either transform all variables, or only ones specified in `vns`.

Use the (inverse of) transformation `t`, or `default_transformation(model, vi)` if one is
not provided.

See also: [`default_transformation`](@ref), [`link!!`](@ref).
"""
invlink!!(vi::AbstractVarInfo, model::Model) = invlink!!(vi, SampleFromPrior(), model)
function invlink!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model)
return invlink!!(t, vi, SampleFromPrior(), model)
function invlink!!(vi::AbstractVarInfo, model::Model)
return invlink!!(default_transformation(model, vi), vi, model)
end
function invlink!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
# Here we extract the `transformation` from `vi` rather than using the default one.
return invlink!!(transformation(vi), vi, spl, model)
function invlink!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
return invlink!!(default_transformation(model, vi), vi, vns, model)
end

# Vector-based ones.
function link!!(
t::StaticTransformation{<:Bijectors.Transform},
vi::AbstractVarInfo,
spl::AbstractSampler,
model::Model,
t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model
)
b = inverse(t.bijector)
x = vi[spl]
x = vi[:]
y, logjac = with_logabsdet_jacobian(b, x)

lp_new = getlogp(vi) - logjac
vi_new = setlogp!!(unflatten(vi, spl, y), lp_new)
vi_new = setlogp!!(unflatten(vi, y), lp_new)
return settrans!!(vi_new, t)
end

function invlink!!(
t::StaticTransformation{<:Bijectors.Transform},
vi::AbstractVarInfo,
spl::AbstractSampler,
model::Model,
t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model
)
b = t.bijector
y = vi[spl]
y = vi[:]
x, logjac = with_logabsdet_jacobian(b, y)

lp_new = getlogp(vi) + logjac
vi_new = setlogp!!(unflatten(vi, spl, x), lp_new)
vi_new = setlogp!!(unflatten(vi, x), lp_new)
return settrans!!(vi_new, NoTransformation())
end

"""
invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model)

Transform variables in `vi` to their constrained space without mutating `vi`.

Transform the variables in `vi` to their constrained space without mutating `vi`, using the (inverse of)
transformation `t`.
Either transform all variables, or only ones specified in `vns`.

If `t` is not provided, `default_transformation(model, vi)` will be used.
Use the (inverse of) transformation `t`, or `default_transformation(model, vi)` if one is
not provided.

See also: [`default_transformation`](@ref), [`link`](@ref).
"""
invlink(vi::AbstractVarInfo, model::Model) = invlink(vi, SampleFromPrior(), model)
function invlink(t::AbstractTransformation, vi::AbstractVarInfo, model::Model)
return invlink(t, vi, SampleFromPrior(), model)
function invlink(vi::AbstractVarInfo, model::Model)
return invlink(default_transformation(model, vi), vi, model)
end
function invlink(vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
return invlink(transformation(vi), vi, spl, model)
function invlink(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
return invlink(default_transformation(model, vi), vi, vns, model)
end

"""
maybe_invlink_before_eval!!([t::Transformation,] vi, context, model)
maybe_invlink_before_eval!!([t::Transformation,] vi, model)

Return a possibly invlinked version of `vi`.

Expand Down Expand Up @@ -698,34 +699,23 @@ julia> # Now performs a single `invlink!!` before model evaluation.
-1001.4189385332047
```
"""
function maybe_invlink_before_eval!!(
vi::AbstractVarInfo, context::AbstractContext, model::Model
)
return maybe_invlink_before_eval!!(transformation(vi), vi, context, model)
function maybe_invlink_before_eval!!(vi::AbstractVarInfo, model::Model)
return maybe_invlink_before_eval!!(transformation(vi), vi, model)
end
function maybe_invlink_before_eval!!(
::NoTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model
)
function maybe_invlink_before_eval!!(::NoTransformation, vi::AbstractVarInfo, model::Model)
return vi
end
function maybe_invlink_before_eval!!(
::DynamicTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model
::DynamicTransformation, vi::AbstractVarInfo, model::Model
)
# `DynamicTransformation` is meant to _not_ do the transformation statically, hence we do nothing.
# `DynamicTransformation` is meant to _not_ do the transformation statically, hence we
# do nothing.
return vi
end
function maybe_invlink_before_eval!!(
t::StaticTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model
t::StaticTransformation, vi::AbstractVarInfo, model::Model
)
return invlink!!(t, vi, _default_sampler(context), model)
end

function _default_sampler(context::AbstractContext)
return _default_sampler(NodeTrait(_default_sampler, context), context)
end
_default_sampler(::IsLeaf, context::AbstractContext) = SampleFromPrior()
function _default_sampler(::IsParent, context::AbstractContext)
return _default_sampler(childcontext(context))
return invlink!!(t, vi, model)
end

# Utilities
Expand Down
8 changes: 4 additions & 4 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -971,7 +971,7 @@ Return the arguments and keyword arguments to be passed to the evaluator of the
# lazy `invlink`-ing of the parameters. This can be useful for
# speeding up computation. See docs for `maybe_invlink_before_eval!!`
# for more information.
maybe_invlink_before_eval!!(varinfo, context_new, model),
maybe_invlink_before_eval!!(varinfo, model),
context_new,
$(unwrap_args...),
)
Expand Down Expand Up @@ -1169,10 +1169,10 @@ end
"""
predict([rng::AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo})

Generate samples from the posterior predictive distribution by evaluating `model` at each set
of parameter values provided in `chain`. The number of posterior predictive samples matches
Generate samples from the posterior predictive distribution by evaluating `model` at each set
of parameter values provided in `chain`. The number of posterior predictive samples matches
the length of `chain`. The returned `AbstractVarInfo`s will contain both the posterior parameter values
and the predicted values.
and the predicted values.
"""
function predict(
rng::Random.AbstractRNG, model::Model, chain::AbstractArray{<:AbstractVarInfo}
Expand Down
6 changes: 2 additions & 4 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -680,8 +680,7 @@ Distributions.loglikelihood(model::Model, θ) = loglikelihood(model, SimpleVarIn
function link!!(
t::StaticTransformation{<:Bijectors.NamedTransform},
vi::SimpleVarInfo{<:NamedTuple},
spl::AbstractSampler,
model::Model,
::Model,
)
# TODO: Make sure that `spl` is respected.
b = inverse(t.bijector)
Expand All @@ -695,8 +694,7 @@ end
function invlink!!(
t::StaticTransformation{<:Bijectors.NamedTransform},
vi::SimpleVarInfo{<:NamedTuple},
spl::AbstractSampler,
model::Model,
::Model,
)
# TODO: Make sure that `spl` is respected.
b = t.bijector
Expand Down
61 changes: 21 additions & 40 deletions src/threadsafe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,70 +81,51 @@ haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn)

islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl)

function link!!(
t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, spl, model)
function link!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...)
return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, args...)
end

function invlink!!(
t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, spl, model)
function invlink!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...)
return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, args...)
end

function link(
t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
return Accessors.@set vi.varinfo = link(t, vi.varinfo, spl, model)
function link(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...)
return Accessors.@set vi.varinfo = link(t, vi.varinfo, args...)
end

function invlink(
t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, spl, model)
function invlink(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...)
return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, args...)
end

# Need to define explicitly for `DynamicTransformation` to avoid method ambiguity.
# NOTE: We also can't just defer to the wrapped varinfo, because we need to ensure
# consistency between `vi.logps` field and `getlogp(vi.varinfo)`, which accumulates
# to define `getlogp(vi)`.
function link!!(
t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t)
end

function invlink!!(
::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
function invlink!!(::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
return settrans!!(
last(evaluate!!(model, vi, DynamicTransformationContext{true}())),
NoTransformation(),
)
end

function link(
t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
return link!!(t, deepcopy(vi), spl, model)
function link(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
return link!!(t, deepcopy(vi), model)
end

function invlink(
t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
return invlink!!(t, deepcopy(vi), spl, model)
function invlink(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
return invlink!!(t, deepcopy(vi), model)
end

function maybe_invlink_before_eval!!(
vi::ThreadSafeVarInfo, context::AbstractContext, model::Model
)
function maybe_invlink_before_eval!!(vi::ThreadSafeVarInfo, model::Model)
# Defer to the wrapped `AbstractVarInfo` object.
# NOTE: When computing `getlogp` for `ThreadSafeVarInfo` we do include the `getlogp(vi.varinfo)`
# hence the log-absdet-jacobian term will correctly be included in the `getlogp(vi)`.
return Accessors.@set vi.varinfo = maybe_invlink_before_eval!!(
vi.varinfo, context, model
)
# NOTE: When computing `getlogp` for `ThreadSafeVarInfo` we do include the
# `getlogp(vi.varinfo)` hence the log-absdet-jacobian term will correctly be included in
# the `getlogp(vi)`.
return Accessors.@set vi.varinfo = maybe_invlink_before_eval!!(vi.varinfo, model)
end

# `getindex`
Expand Down Expand Up @@ -182,8 +163,8 @@ function vector_getranges(vi::ThreadSafeVarInfo, vns::Vector{<:VarName})
return vector_getranges(vi.varinfo, vns)
end

function set_retained_vns_del_by_spl!(vi::ThreadSafeVarInfo, spl::Sampler)
return set_retained_vns_del_by_spl!(vi.varinfo, spl)
function set_retained_vns_del!(vi::ThreadSafeVarInfo, spl::Sampler)
return set_retained_vns_del!(vi.varinfo, spl)
end

isempty(vi::ThreadSafeVarInfo) = isempty(vi.varinfo)
Expand Down
Loading
Loading