Skip to content

Commit

Permalink
Remove selector stuff from VarInfo tests and link/invlink (#780)
Browse files Browse the repository at this point in the history
* Remove selector stuff from varinfo tests

* Implement link and invlink for varnames rather than samplers

* Replace set_retained_vns_del_by_spl! with set_retained_vns_del!

* Make linking tests more extensive

* Remove sampler indexing from link methods (but not invlink)

* Remove indexing by samplers from invlink

* Work towards removing sampler indexing with StaticTransformation

* Fix invlink/link for TypedVarInfo and StaticTransformation

* Fix a test in models.jl

* Move some functions to utils.jl, add tests and docstrings

* Fix a docstring typo

* Various simplification to link/invlink

* Improve a docstring

* Style improvements

* Fix broken link/invlink dispatch cascade for VectorVarInfo

* Fix some more broken dispatch cascades

* Apply suggestions from code review

Co-authored-by: Xianda Sun <[email protected]>

* Remove comments that messed with docstrings

* Apply suggestions from code review

Co-authored-by: Penelope Yong <[email protected]>

* Fix issues surfaced in code review

* Simplify link/invlink arguments

* Fix a bug in unflatten VarNamedVector

* Rename VarNameCollection -> VarNameTuple

* Remove test of a removed varname_namedtuple method

* Apply suggestions from code review

Co-authored-by: Penelope Yong <[email protected]>

* Respond to review feedback

* Remove _default_sampler and a dead argument of maybe_invlink_before_eval

* Fix a typo in a comment

* Add HISTORY entry, fix one set_retained_vns_del! method

---------

Co-authored-by: Xianda Sun <[email protected]>
Co-authored-by: Penelope Yong <[email protected]>
  • Loading branch information
3 people authored Jan 30, 2025
1 parent 7140f3d commit c5f2f7a
Show file tree
Hide file tree
Showing 15 changed files with 513 additions and 501 deletions.
9 changes: 9 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@

**Breaking**

### Remove indexing by samplers

This release removes the feature of `VarInfo` where it kept track of which variable was associated with which sampler. This means removing all user-facing methods where `VarInfo`s where being indexed with samplers. In particular,

- `link` and `invlink`, and their `!!` versions, no longer accept a sampler as an argument to specify which variables to (inv)link. The `link(varinfo, model)` methods remain in place, and as a new addition one can give a `Tuple` of `VarName`s to (inv)link only select variables, as in `link(varinfo, varname_tuple, model)`.
- `set_retained_vns_del_by_spl!` has been replaced by `set_retained_vns_del!` which applies to all variables.

### Reverse prefixing order

- For submodels constructed using `to_submodel`, the order in which nested prefixes are applied has been changed.
Previously, the order was that outer prefixes were applied first, then inner ones.
This version reverses that.
Expand Down
2 changes: 1 addition & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ set_num_produce!
increment_num_produce!
reset_num_produce!
setorder!
set_retained_vns_del_by_spl!
set_retained_vns_del!
```

```@docs
Expand Down
2 changes: 1 addition & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ export AbstractVarInfo,
set_num_produce!,
reset_num_produce!,
increment_num_produce!,
set_retained_vns_del_by_spl!,
set_retained_vns_del!,
is_flagged,
set_flag!,
unset_flag!,
Expand Down
130 changes: 60 additions & 70 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -537,117 +537,118 @@ If `vn` is not specified, then `istrans(vi)` evaluates to `true` for all variabl
"""
function settrans!! end

# For link!!, invlink!!, link, and invlink, we deliberately do not provide a fallback
# method for the case when no `vns` is provided, that would get all the keys from the
# `VarInfo`. Hence each subtype of `AbstractVarInfo` needs to implement separately the case
# where `vns` is provided and the one where it is not. This is because having separate
# implementations is typically much more performant, and because not all AbstractVarInfo
# types support partial linking.

"""
link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model)
Transform variables in `vi` to their linked space, mutating `vi` if possible.
Transform the variables in `vi` to their linked space, using the transformation `t`,
mutating `vi` if possible.
Either transform all variables, or only ones specified in `vns`.
If `t` is not provided, `default_transformation(model, vi)` will be used.
Use the transformation `t`, or `default_transformation(model, vi)` if one is not provided.
See also: [`default_transformation`](@ref), [`invlink!!`](@ref).
"""
link!!(vi::AbstractVarInfo, model::Model) = link!!(vi, SampleFromPrior(), model)
function link!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model)
return link!!(t, vi, SampleFromPrior(), model)
function link!!(vi::AbstractVarInfo, model::Model)
return link!!(default_transformation(model, vi), vi, model)
end
function link!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
# Use `default_transformation` to decide which transformation to use if none is specified.
return link!!(default_transformation(model, vi), vi, spl, model)
function link!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
return link!!(default_transformation(model, vi), vi, vns, model)
end

"""
link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
link([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
link([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model)
Transform variables in `vi` to their linked space without mutating `vi`.
Transform the variables in `vi` to their linked space without mutating `vi`, using the transformation `t`.
Either transform all variables, or only ones specified in `vns`.
If `t` is not provided, `default_transformation(model, vi)` will be used.
Use the transformation `t`, or `default_transformation(model, vi)` if one is not provided.
See also: [`default_transformation`](@ref), [`invlink`](@ref).
"""
link(vi::AbstractVarInfo, model::Model) = link(vi, SampleFromPrior(), model)
function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model)
return link(t, deepcopy(vi), SampleFromPrior(), model)
function link(vi::AbstractVarInfo, model::Model)
return link(default_transformation(model, vi), vi, model)
end
function link(vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
# Use `default_transformation` to decide which transformation to use if none is specified.
return link(default_transformation(model, vi), deepcopy(vi), spl, model)
function link(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
return link(default_transformation(model, vi), vi, vns, model)
end

"""
invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model)
Transform the variables in `vi` to their constrained space, using the (inverse of)
transformation `t`, mutating `vi` if possible.
Transform variables in `vi` to their constrained space, mutating `vi` if possible.
If `t` is not provided, `default_transformation(model, vi)` will be used.
Either transform all variables, or only ones specified in `vns`.
Use the (inverse of) transformation `t`, or `default_transformation(model, vi)` if one is
not provided.
See also: [`default_transformation`](@ref), [`link!!`](@ref).
"""
invlink!!(vi::AbstractVarInfo, model::Model) = invlink!!(vi, SampleFromPrior(), model)
function invlink!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model)
return invlink!!(t, vi, SampleFromPrior(), model)
function invlink!!(vi::AbstractVarInfo, model::Model)
return invlink!!(default_transformation(model, vi), vi, model)
end
function invlink!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
# Here we extract the `transformation` from `vi` rather than using the default one.
return invlink!!(transformation(vi), vi, spl, model)
function invlink!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
return invlink!!(default_transformation(model, vi), vi, vns, model)
end

# Vector-based ones.
function link!!(
t::StaticTransformation{<:Bijectors.Transform},
vi::AbstractVarInfo,
spl::AbstractSampler,
model::Model,
t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model
)
b = inverse(t.bijector)
x = vi[spl]
x = vi[:]
y, logjac = with_logabsdet_jacobian(b, x)

lp_new = getlogp(vi) - logjac
vi_new = setlogp!!(unflatten(vi, spl, y), lp_new)
vi_new = setlogp!!(unflatten(vi, y), lp_new)
return settrans!!(vi_new, t)
end

function invlink!!(
t::StaticTransformation{<:Bijectors.Transform},
vi::AbstractVarInfo,
spl::AbstractSampler,
model::Model,
t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model
)
b = t.bijector
y = vi[spl]
y = vi[:]
x, logjac = with_logabsdet_jacobian(b, y)

lp_new = getlogp(vi) + logjac
vi_new = setlogp!!(unflatten(vi, spl, x), lp_new)
vi_new = setlogp!!(unflatten(vi, x), lp_new)
return settrans!!(vi_new, NoTransformation())
end

"""
invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model)
Transform variables in `vi` to their constrained space without mutating `vi`.
Transform the variables in `vi` to their constrained space without mutating `vi`, using the (inverse of)
transformation `t`.
Either transform all variables, or only ones specified in `vns`.
If `t` is not provided, `default_transformation(model, vi)` will be used.
Use the (inverse of) transformation `t`, or `default_transformation(model, vi)` if one is
not provided.
See also: [`default_transformation`](@ref), [`link`](@ref).
"""
invlink(vi::AbstractVarInfo, model::Model) = invlink(vi, SampleFromPrior(), model)
function invlink(t::AbstractTransformation, vi::AbstractVarInfo, model::Model)
return invlink(t, vi, SampleFromPrior(), model)
function invlink(vi::AbstractVarInfo, model::Model)
return invlink(default_transformation(model, vi), vi, model)
end
function invlink(vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
return invlink(transformation(vi), vi, spl, model)
function invlink(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
return invlink(default_transformation(model, vi), vi, vns, model)
end

"""
maybe_invlink_before_eval!!([t::Transformation,] vi, context, model)
maybe_invlink_before_eval!!([t::Transformation,] vi, model)
Return a possibly invlinked version of `vi`.
Expand Down Expand Up @@ -698,34 +699,23 @@ julia> # Now performs a single `invlink!!` before model evaluation.
-1001.4189385332047
```
"""
function maybe_invlink_before_eval!!(
vi::AbstractVarInfo, context::AbstractContext, model::Model
)
return maybe_invlink_before_eval!!(transformation(vi), vi, context, model)
function maybe_invlink_before_eval!!(vi::AbstractVarInfo, model::Model)
return maybe_invlink_before_eval!!(transformation(vi), vi, model)
end
function maybe_invlink_before_eval!!(
::NoTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model
)
function maybe_invlink_before_eval!!(::NoTransformation, vi::AbstractVarInfo, model::Model)
return vi
end
function maybe_invlink_before_eval!!(
::DynamicTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model
::DynamicTransformation, vi::AbstractVarInfo, model::Model
)
# `DynamicTransformation` is meant to _not_ do the transformation statically, hence we do nothing.
# `DynamicTransformation` is meant to _not_ do the transformation statically, hence we
# do nothing.
return vi
end
function maybe_invlink_before_eval!!(
t::StaticTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model
t::StaticTransformation, vi::AbstractVarInfo, model::Model
)
return invlink!!(t, vi, _default_sampler(context), model)
end

function _default_sampler(context::AbstractContext)
return _default_sampler(NodeTrait(_default_sampler, context), context)
end
_default_sampler(::IsLeaf, context::AbstractContext) = SampleFromPrior()
function _default_sampler(::IsParent, context::AbstractContext)
return _default_sampler(childcontext(context))
return invlink!!(t, vi, model)
end

# Utilities
Expand Down
8 changes: 4 additions & 4 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -971,7 +971,7 @@ Return the arguments and keyword arguments to be passed to the evaluator of the
# lazy `invlink`-ing of the parameters. This can be useful for
# speeding up computation. See docs for `maybe_invlink_before_eval!!`
# for more information.
maybe_invlink_before_eval!!(varinfo, context_new, model),
maybe_invlink_before_eval!!(varinfo, model),
context_new,
$(unwrap_args...),
)
Expand Down Expand Up @@ -1169,10 +1169,10 @@ end
"""
predict([rng::AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo})
Generate samples from the posterior predictive distribution by evaluating `model` at each set
of parameter values provided in `chain`. The number of posterior predictive samples matches
Generate samples from the posterior predictive distribution by evaluating `model` at each set
of parameter values provided in `chain`. The number of posterior predictive samples matches
the length of `chain`. The returned `AbstractVarInfo`s will contain both the posterior parameter values
and the predicted values.
and the predicted values.
"""
function predict(
rng::Random.AbstractRNG, model::Model, chain::AbstractArray{<:AbstractVarInfo}
Expand Down
6 changes: 2 additions & 4 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -680,8 +680,7 @@ Distributions.loglikelihood(model::Model, θ) = loglikelihood(model, SimpleVarIn
function link!!(
t::StaticTransformation{<:Bijectors.NamedTransform},
vi::SimpleVarInfo{<:NamedTuple},
spl::AbstractSampler,
model::Model,
::Model,
)
# TODO: Make sure that `spl` is respected.
b = inverse(t.bijector)
Expand All @@ -695,8 +694,7 @@ end
function invlink!!(
t::StaticTransformation{<:Bijectors.NamedTransform},
vi::SimpleVarInfo{<:NamedTuple},
spl::AbstractSampler,
model::Model,
::Model,
)
# TODO: Make sure that `spl` is respected.
b = t.bijector
Expand Down
61 changes: 21 additions & 40 deletions src/threadsafe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,70 +81,51 @@ haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn)

islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl)

function link!!(
t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, spl, model)
function link!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...)
return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, args...)
end

function invlink!!(
t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, spl, model)
function invlink!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...)
return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, args...)
end

function link(
t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
return Accessors.@set vi.varinfo = link(t, vi.varinfo, spl, model)
function link(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...)
return Accessors.@set vi.varinfo = link(t, vi.varinfo, args...)
end

function invlink(
t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, spl, model)
function invlink(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...)
return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, args...)
end

# Need to define explicitly for `DynamicTransformation` to avoid method ambiguity.
# NOTE: We also can't just defer to the wrapped varinfo, because we need to ensure
# consistency between `vi.logps` field and `getlogp(vi.varinfo)`, which accumulates
# to define `getlogp(vi)`.
function link!!(
t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t)
end

function invlink!!(
::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
function invlink!!(::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
return settrans!!(
last(evaluate!!(model, vi, DynamicTransformationContext{true}())),
NoTransformation(),
)
end

function link(
t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
return link!!(t, deepcopy(vi), spl, model)
function link(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
return link!!(t, deepcopy(vi), model)
end

function invlink(
t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
return invlink!!(t, deepcopy(vi), spl, model)
function invlink(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
return invlink!!(t, deepcopy(vi), model)
end

function maybe_invlink_before_eval!!(
vi::ThreadSafeVarInfo, context::AbstractContext, model::Model
)
function maybe_invlink_before_eval!!(vi::ThreadSafeVarInfo, model::Model)
# Defer to the wrapped `AbstractVarInfo` object.
# NOTE: When computing `getlogp` for `ThreadSafeVarInfo` we do include the `getlogp(vi.varinfo)`
# hence the log-absdet-jacobian term will correctly be included in the `getlogp(vi)`.
return Accessors.@set vi.varinfo = maybe_invlink_before_eval!!(
vi.varinfo, context, model
)
# NOTE: When computing `getlogp` for `ThreadSafeVarInfo` we do include the
# `getlogp(vi.varinfo)` hence the log-absdet-jacobian term will correctly be included in
# the `getlogp(vi)`.
return Accessors.@set vi.varinfo = maybe_invlink_before_eval!!(vi.varinfo, model)
end

# `getindex`
Expand Down Expand Up @@ -182,8 +163,8 @@ function vector_getranges(vi::ThreadSafeVarInfo, vns::Vector{<:VarName})
return vector_getranges(vi.varinfo, vns)
end

function set_retained_vns_del_by_spl!(vi::ThreadSafeVarInfo, spl::Sampler)
return set_retained_vns_del_by_spl!(vi.varinfo, spl)
function set_retained_vns_del!(vi::ThreadSafeVarInfo)
return set_retained_vns_del!(vi.varinfo)
end

isempty(vi::ThreadSafeVarInfo) = isempty(vi.varinfo)
Expand Down
Loading

0 comments on commit c5f2f7a

Please sign in to comment.