From c5f2f7a14566e9c73884404f57f8fde6165c7aed Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 30 Jan 2025 12:46:56 +0000 Subject: [PATCH] Remove selector stuff from VarInfo tests and link/invlink (#780) * Remove selector stuff from varinfo tests * Implement link and invlink for varnames rather than samplers * Replace set_retained_vns_del_by_spl! with set_retained_vns_del! * Make linking tests more extensive * Remove sampler indexing from link methods (but not invlink) * Remove indexing by samplers from invlink * Work towards removing sampler indexing with StaticTransformation * Fix invlink/link for TypedVarInfo and StaticTransformation * Fix a test in models.jl * Move some functions to utils.jl, add tests and docstrings * Fix a docstring typo * Various simplification to link/invlink * Improve a docstring * Style improvements * Fix broken link/invlink dispatch cascade for VectorVarInfo * Fix some more broken dispatch cascades * Apply suggestions from code review Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com> * Remove comments that messed with docstrings * Apply suggestions from code review Co-authored-by: Penelope Yong * Fix issues surfaced in code review * Simplify link/invlink arguments * Fix a bug in unflatten VarNamedVector * Rename VarNameCollection -> VarNameTuple * Remove test of a removed varname_namedtuple method * Apply suggestions from code review Co-authored-by: Penelope Yong * Respond to review feedback * Remove _default_sampler and a dead argument of maybe_invlink_before_eval * Fix a typo in a comment * Add HISTORY entry, fix one set_retained_vns_del! method --------- Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Co-authored-by: Penelope Yong --- HISTORY.md | 9 + docs/src/api.md | 2 +- src/DynamicPPL.jl | 2 +- src/abstract_varinfo.jl | 130 +++++++------- src/model.jl | 8 +- src/simple_varinfo.jl | 6 +- src/threadsafe.jl | 61 +++---- src/transforming.jl | 20 +-- src/utils.jl | 47 ++++++ src/varinfo.jl | 364 ++++++++++++++++++++++++++-------------- src/varnamedvector.jl | 7 +- test/model.jl | 2 +- test/simple_varinfo.jl | 4 +- test/utils.jl | 23 +++ test/varinfo.jl | 329 ++++++++++-------------------------- 15 files changed, 513 insertions(+), 501 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 0b9e8091b..03c564b64 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,6 +4,15 @@ **Breaking** +### Remove indexing by samplers + +This release removes the feature of `VarInfo` where it kept track of which variable was associated with which sampler. This means removing all user-facing methods where `VarInfo`s where being indexed with samplers. In particular, + + - `link` and `invlink`, and their `!!` versions, no longer accept a sampler as an argument to specify which variables to (inv)link. The `link(varinfo, model)` methods remain in place, and as a new addition one can give a `Tuple` of `VarName`s to (inv)link only select variables, as in `link(varinfo, varname_tuple, model)`. + - `set_retained_vns_del_by_spl!` has been replaced by `set_retained_vns_del!` which applies to all variables. + +### Reverse prefixing order + - For submodels constructed using `to_submodel`, the order in which nested prefixes are applied has been changed. Previously, the order was that outer prefixes were applied first, then inner ones. This version reverses that. diff --git a/docs/src/api.md b/docs/src/api.md index 093cb06a6..36dd24250 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -304,7 +304,7 @@ set_num_produce! increment_num_produce! reset_num_produce! setorder! -set_retained_vns_del_by_spl! +set_retained_vns_del! ``` ```@docs diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index c1cdbd94e..55e1f7e88 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -59,7 +59,7 @@ export AbstractVarInfo, set_num_produce!, reset_num_produce!, increment_num_produce!, - set_retained_vns_del_by_spl!, + set_retained_vns_del!, is_flagged, set_flag!, unset_flag!, diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 3f513d71d..26c4268d8 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -537,117 +537,118 @@ If `vn` is not specified, then `istrans(vi)` evaluates to `true` for all variabl """ function settrans!! end +# For link!!, invlink!!, link, and invlink, we deliberately do not provide a fallback +# method for the case when no `vns` is provided, that would get all the keys from the +# `VarInfo`. Hence each subtype of `AbstractVarInfo` needs to implement separately the case +# where `vns` is provided and the one where it is not. This is because having separate +# implementations is typically much more performant, and because not all AbstractVarInfo +# types support partial linking. + """ link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) + link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model) + +Transform variables in `vi` to their linked space, mutating `vi` if possible. -Transform the variables in `vi` to their linked space, using the transformation `t`, -mutating `vi` if possible. +Either transform all variables, or only ones specified in `vns`. -If `t` is not provided, `default_transformation(model, vi)` will be used. +Use the transformation `t`, or `default_transformation(model, vi)` if one is not provided. See also: [`default_transformation`](@ref), [`invlink!!`](@ref). """ -link!!(vi::AbstractVarInfo, model::Model) = link!!(vi, SampleFromPrior(), model) -function link!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) - return link!!(t, vi, SampleFromPrior(), model) +function link!!(vi::AbstractVarInfo, model::Model) + return link!!(default_transformation(model, vi), vi, model) end -function link!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) - # Use `default_transformation` to decide which transformation to use if none is specified. - return link!!(default_transformation(model, vi), vi, spl, model) +function link!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) + return link!!(default_transformation(model, vi), vi, vns, model) end """ link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - link([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) + link([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model) + +Transform variables in `vi` to their linked space without mutating `vi`. -Transform the variables in `vi` to their linked space without mutating `vi`, using the transformation `t`. +Either transform all variables, or only ones specified in `vns`. -If `t` is not provided, `default_transformation(model, vi)` will be used. +Use the transformation `t`, or `default_transformation(model, vi)` if one is not provided. See also: [`default_transformation`](@ref), [`invlink`](@ref). """ -link(vi::AbstractVarInfo, model::Model) = link(vi, SampleFromPrior(), model) -function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) - return link(t, deepcopy(vi), SampleFromPrior(), model) +function link(vi::AbstractVarInfo, model::Model) + return link(default_transformation(model, vi), vi, model) end -function link(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) - # Use `default_transformation` to decide which transformation to use if none is specified. - return link(default_transformation(model, vi), deepcopy(vi), spl, model) +function link(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) + return link(default_transformation(model, vi), vi, vns, model) end """ invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) + invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model) -Transform the variables in `vi` to their constrained space, using the (inverse of) -transformation `t`, mutating `vi` if possible. +Transform variables in `vi` to their constrained space, mutating `vi` if possible. -If `t` is not provided, `default_transformation(model, vi)` will be used. +Either transform all variables, or only ones specified in `vns`. + +Use the (inverse of) transformation `t`, or `default_transformation(model, vi)` if one is +not provided. See also: [`default_transformation`](@ref), [`link!!`](@ref). """ -invlink!!(vi::AbstractVarInfo, model::Model) = invlink!!(vi, SampleFromPrior(), model) -function invlink!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) - return invlink!!(t, vi, SampleFromPrior(), model) +function invlink!!(vi::AbstractVarInfo, model::Model) + return invlink!!(default_transformation(model, vi), vi, model) end -function invlink!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) - # Here we extract the `transformation` from `vi` rather than using the default one. - return invlink!!(transformation(vi), vi, spl, model) +function invlink!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) + return invlink!!(default_transformation(model, vi), vi, vns, model) end # Vector-based ones. function link!!( - t::StaticTransformation{<:Bijectors.Transform}, - vi::AbstractVarInfo, - spl::AbstractSampler, - model::Model, + t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model ) b = inverse(t.bijector) - x = vi[spl] + x = vi[:] y, logjac = with_logabsdet_jacobian(b, x) lp_new = getlogp(vi) - logjac - vi_new = setlogp!!(unflatten(vi, spl, y), lp_new) + vi_new = setlogp!!(unflatten(vi, y), lp_new) return settrans!!(vi_new, t) end function invlink!!( - t::StaticTransformation{<:Bijectors.Transform}, - vi::AbstractVarInfo, - spl::AbstractSampler, - model::Model, + t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model ) b = t.bijector - y = vi[spl] + y = vi[:] x, logjac = with_logabsdet_jacobian(b, y) lp_new = getlogp(vi) + logjac - vi_new = setlogp!!(unflatten(vi, spl, x), lp_new) + vi_new = setlogp!!(unflatten(vi, x), lp_new) return settrans!!(vi_new, NoTransformation()) end """ invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) + invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model) + +Transform variables in `vi` to their constrained space without mutating `vi`. -Transform the variables in `vi` to their constrained space without mutating `vi`, using the (inverse of) -transformation `t`. +Either transform all variables, or only ones specified in `vns`. -If `t` is not provided, `default_transformation(model, vi)` will be used. +Use the (inverse of) transformation `t`, or `default_transformation(model, vi)` if one is +not provided. See also: [`default_transformation`](@ref), [`link`](@ref). """ -invlink(vi::AbstractVarInfo, model::Model) = invlink(vi, SampleFromPrior(), model) -function invlink(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) - return invlink(t, vi, SampleFromPrior(), model) +function invlink(vi::AbstractVarInfo, model::Model) + return invlink(default_transformation(model, vi), vi, model) end -function invlink(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) - return invlink(transformation(vi), vi, spl, model) +function invlink(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) + return invlink(default_transformation(model, vi), vi, vns, model) end """ - maybe_invlink_before_eval!!([t::Transformation,] vi, context, model) + maybe_invlink_before_eval!!([t::Transformation,] vi, model) Return a possibly invlinked version of `vi`. @@ -698,34 +699,23 @@ julia> # Now performs a single `invlink!!` before model evaluation. -1001.4189385332047 ``` """ -function maybe_invlink_before_eval!!( - vi::AbstractVarInfo, context::AbstractContext, model::Model -) - return maybe_invlink_before_eval!!(transformation(vi), vi, context, model) +function maybe_invlink_before_eval!!(vi::AbstractVarInfo, model::Model) + return maybe_invlink_before_eval!!(transformation(vi), vi, model) end -function maybe_invlink_before_eval!!( - ::NoTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model -) +function maybe_invlink_before_eval!!(::NoTransformation, vi::AbstractVarInfo, model::Model) return vi end function maybe_invlink_before_eval!!( - ::DynamicTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model + ::DynamicTransformation, vi::AbstractVarInfo, model::Model ) - # `DynamicTransformation` is meant to _not_ do the transformation statically, hence we do nothing. + # `DynamicTransformation` is meant to _not_ do the transformation statically, hence we + # do nothing. return vi end function maybe_invlink_before_eval!!( - t::StaticTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model + t::StaticTransformation, vi::AbstractVarInfo, model::Model ) - return invlink!!(t, vi, _default_sampler(context), model) -end - -function _default_sampler(context::AbstractContext) - return _default_sampler(NodeTrait(_default_sampler, context), context) -end -_default_sampler(::IsLeaf, context::AbstractContext) = SampleFromPrior() -function _default_sampler(::IsParent, context::AbstractContext) - return _default_sampler(childcontext(context)) + return invlink!!(t, vi, model) end # Utilities diff --git a/src/model.jl b/src/model.jl index 6fb0b40b0..462db7397 100644 --- a/src/model.jl +++ b/src/model.jl @@ -971,7 +971,7 @@ Return the arguments and keyword arguments to be passed to the evaluator of the # lazy `invlink`-ing of the parameters. This can be useful for # speeding up computation. See docs for `maybe_invlink_before_eval!!` # for more information. - maybe_invlink_before_eval!!(varinfo, context_new, model), + maybe_invlink_before_eval!!(varinfo, model), context_new, $(unwrap_args...), ) @@ -1169,10 +1169,10 @@ end """ predict([rng::AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo}) -Generate samples from the posterior predictive distribution by evaluating `model` at each set -of parameter values provided in `chain`. The number of posterior predictive samples matches +Generate samples from the posterior predictive distribution by evaluating `model` at each set +of parameter values provided in `chain`. The number of posterior predictive samples matches the length of `chain`. The returned `AbstractVarInfo`s will contain both the posterior parameter values -and the predicted values. +and the predicted values. """ function predict( rng::Random.AbstractRNG, model::Model, chain::AbstractArray{<:AbstractVarInfo} diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index b6a84238e..57b167077 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -680,8 +680,7 @@ Distributions.loglikelihood(model::Model, θ) = loglikelihood(model, SimpleVarIn function link!!( t::StaticTransformation{<:Bijectors.NamedTransform}, vi::SimpleVarInfo{<:NamedTuple}, - spl::AbstractSampler, - model::Model, + ::Model, ) # TODO: Make sure that `spl` is respected. b = inverse(t.bijector) @@ -695,8 +694,7 @@ end function invlink!!( t::StaticTransformation{<:Bijectors.NamedTransform}, vi::SimpleVarInfo{<:NamedTuple}, - spl::AbstractSampler, - model::Model, + ::Model, ) # TODO: Make sure that `spl` is respected. b = t.bijector diff --git a/src/threadsafe.jl b/src/threadsafe.jl index cedb0efad..69be5dcb1 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -81,70 +81,51 @@ haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn) islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl) -function link!!( - t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model -) - return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, spl, model) +function link!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) + return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, args...) end -function invlink!!( - t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model -) - return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, spl, model) +function invlink!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) + return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, args...) end -function link( - t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model -) - return Accessors.@set vi.varinfo = link(t, vi.varinfo, spl, model) +function link(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) + return Accessors.@set vi.varinfo = link(t, vi.varinfo, args...) end -function invlink( - t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model -) - return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, spl, model) +function invlink(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) + return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, args...) end # Need to define explicitly for `DynamicTransformation` to avoid method ambiguity. # NOTE: We also can't just defer to the wrapped varinfo, because we need to ensure # consistency between `vi.logps` field and `getlogp(vi.varinfo)`, which accumulates # to define `getlogp(vi)`. -function link!!( - t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model -) +function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) end -function invlink!!( - ::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model -) +function invlink!!(::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) return settrans!!( last(evaluate!!(model, vi, DynamicTransformationContext{true}())), NoTransformation(), ) end -function link( - t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model -) - return link!!(t, deepcopy(vi), spl, model) +function link(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) + return link!!(t, deepcopy(vi), model) end -function invlink( - t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model -) - return invlink!!(t, deepcopy(vi), spl, model) +function invlink(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) + return invlink!!(t, deepcopy(vi), model) end -function maybe_invlink_before_eval!!( - vi::ThreadSafeVarInfo, context::AbstractContext, model::Model -) +function maybe_invlink_before_eval!!(vi::ThreadSafeVarInfo, model::Model) # Defer to the wrapped `AbstractVarInfo` object. - # NOTE: When computing `getlogp` for `ThreadSafeVarInfo` we do include the `getlogp(vi.varinfo)` - # hence the log-absdet-jacobian term will correctly be included in the `getlogp(vi)`. - return Accessors.@set vi.varinfo = maybe_invlink_before_eval!!( - vi.varinfo, context, model - ) + # NOTE: When computing `getlogp` for `ThreadSafeVarInfo` we do include the + # `getlogp(vi.varinfo)` hence the log-absdet-jacobian term will correctly be included in + # the `getlogp(vi)`. + return Accessors.@set vi.varinfo = maybe_invlink_before_eval!!(vi.varinfo, model) end # `getindex` @@ -182,8 +163,8 @@ function vector_getranges(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) return vector_getranges(vi.varinfo, vns) end -function set_retained_vns_del_by_spl!(vi::ThreadSafeVarInfo, spl::Sampler) - return set_retained_vns_del_by_spl!(vi.varinfo, spl) +function set_retained_vns_del!(vi::ThreadSafeVarInfo) + return set_retained_vns_del!(vi.varinfo) end isempty(vi::ThreadSafeVarInfo) = isempty(vi.varinfo) diff --git a/src/transforming.jl b/src/transforming.jl index 1f6c55e24..1a26d212f 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -91,29 +91,21 @@ function dot_tilde_assume( return r, lp, vi end -function link!!( - t::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model -) +function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) end -function invlink!!( - ::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model -) +function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model) return settrans!!( last(evaluate!!(model, vi, DynamicTransformationContext{true}())), NoTransformation(), ) end -function link( - t::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model -) - return link!!(t, deepcopy(vi), spl, model) +function link(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) + return link!!(t, deepcopy(vi), model) end -function invlink( - t::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model -) - return invlink!!(t, deepcopy(vi), spl, model) +function invlink(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) + return invlink!!(t, deepcopy(vi), model) end diff --git a/src/utils.jl b/src/utils.jl index 5fedd3039..2539b7179 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -2,6 +2,9 @@ struct NoDefault end const NO_DEFAULT = NoDefault() +# A short-hand for a type commonly used in type signatures for VarInfo methods. +VarNameTuple = NTuple{N,VarName} where {N} + """ @addlogprob!(ex) @@ -1268,3 +1271,47 @@ _merge(left::NamedTuple, right::NamedTuple) = merge(left, right) _merge(left::AbstractDict, right::AbstractDict) = merge(left, right) _merge(left::AbstractDict, right::NamedTuple{()}) = left _merge(left::NamedTuple{()}, right::AbstractDict) = right + +""" + unique_syms(vns::T) where {T<:NTuple{N,VarName}} + +Return the unique symbols of the variables in `vns`. + +Note that `unique_syms` is only defined for `Tuple`s of `VarName`s and, unlike +`Base.unique`, returns a `Tuple`. The point of `unique_syms` is that it supports constant +propagating the result, which is possible only when the argument and the return value are +`Tuple`s. +""" +@generated function unique_syms(::T) where {T<:VarNameTuple} + retval = Expr(:tuple) + syms = [first(vn.parameters) for vn in T.parameters] + for sym in unique(syms) + push!(retval.args, QuoteNode(sym)) + end + return retval +end + +""" + group_varnames_by_symbol(vns::NTuple{N,VarName}) where {N} + +Return a `NamedTuple` of the variables in `vns` grouped by symbol. + +Note that `group_varnames_by_symbol` only accepts a `Tuple` of `VarName`s. This allows it to +be type stable. + +Example: +```julia +julia> vns_tuple = (@varname(x), @varname(y[1]), @varname(x.a), @varname(z[15]), @varname(y[2])) +(x, y[1], x.a, z[15], y[2]) + +julia> vns_nt = (; x=[@varname(x), @varname(x.a)], y=[@varname(y[1]), @varname(y[2])], z=[@varname(z[15])]) +(x = VarName{:x}[x, x.a], y = VarName{:y, IndexLens{Tuple{Int64}}}[y[1], y[2]], z = VarName{:z, IndexLens{Tuple{Int64}}}[z[15]]) + +julia> group_varnames_by_symbol(vns_tuple) == vns_nt +``` +""" +function group_varnames_by_symbol(vns::VarNameTuple) + syms = unique_syms(vns) + elements = map(collect, tuple((filter(vn -> getsym(vn) == s, vns) for s in syms)...)) + return NamedTuple{syms}(elements) +end diff --git a/src/varinfo.jl b/src/varinfo.jl index 3f36cc293..09f5960c1 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -791,6 +791,9 @@ Returns a tuple of the unique symbols of random variables sampled in `vi`. syms(vi::UntypedVarInfo) = Tuple(unique!(map(getsym, vi.metadata.vns))) # get all symbols syms(vi::TypedVarInfo) = keys(vi.metadata) +_getidcs(vi::UntypedVarInfo) = 1:length(vi.metadata.idcs) +_getidcs(vi::TypedVarInfo) = _getidcs(vi.metadata) + # Get all indices of variables belonging to SampleFromPrior: # if the gid/selector of a var is an empty Set, then that var is assumed to be assigned to # the SampleFromPrior sampler @@ -897,6 +900,22 @@ end return :($(exprs...),) end +""" + all_varnames_grouped_by_symbol(vi::TypedVarInfo) + +Return a `NamedTuple` of the variables in `vi` grouped by symbol. +""" +all_varnames_grouped_by_symbol(vi::TypedVarInfo) = + all_varnames_grouped_by_symbol(vi.metadata) + +@generated function all_varnames_grouped_by_symbol(md::NamedTuple{names}) where {names} + expr = Expr(:tuple) + for f in names + push!(expr.args, :($f = keys(md.$f))) + end + return expr +end + # Get the index (in vals) ranges of all the vns of variables belonging to spl @inline function _getranges(vi::VarInfo, spl::Sampler) ## Uncomment the spl.info stuff when it is concretely typed, not Dict{Symbol, Any} @@ -1150,29 +1169,50 @@ _isempty(vnv::VarNamedVector) = isempty(vnv) return Expr(:&&, (:(_isempty(metadata.$f)) for f in names)...) end +function link!!(::DynamicTransformation, vi::TypedVarInfo, model::Model) + vns = all_varnames_grouped_by_symbol(vi) + # If we're working with a `VarNamedVector`, we always use immutable. + has_varnamedvector(vi) && return _link(model, vi, vns) + _link!(vi, vns) + return vi +end + +function link!!(::DynamicTransformation, vi::VarInfo, model::Model) + vns = keys(vi) + # If we're working with a `VarNamedVector`, we always use immutable. + has_varnamedvector(vi) && return _link(model, vi, vns) + _link!(vi, vns) + return vi +end + +function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, model::Model) + # By default this will simply evaluate the model with `DynamicTransformationContext`, + # and so we need to specialize to avoid this. + return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, model) +end + # X -> R for all variables associated with given sampler -function link!!(t::DynamicTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) +function link!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. - has_varnamedvector(vi) && return link(t, vi, spl, model) + has_varnamedvector(vi) && return _link(model, vi, vns) # Call `_link!` instead of `link!` to avoid deprecation warning. - _link!(vi, spl) + _link!(vi, vns) return vi end function link!!( t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, - spl::AbstractSampler, + vns::VarNameTuple, model::Model, ) - # By default this will simply evaluate the model with `DynamicTransformationContext`, and so - # we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, spl, model) + # By default this will simply evaluate the model with `DynamicTransformationContext`, + # and so we need to specialize to avoid this. + return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, vns, model) end -function _link!(vi::UntypedVarInfo, spl::AbstractSampler) +function _link!(vi::UntypedVarInfo, vns) # TODO: Change to a lazy iterator over `vns` - vns = _getvns(vi, spl) if ~istrans(vi, vns[1]) for vn in vns f = internal_to_linked_internal_transform(vi, vn) @@ -1183,24 +1223,41 @@ function _link!(vi::UntypedVarInfo, spl::AbstractSampler) @warn("[DynamicPPL] attempt to link a linked vi") end end -function _link!(vi::TypedVarInfo, spl::AbstractSampler) - return _link!(vi, spl, Val(getspace(spl))) + +# If we try to _link! a TypedVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the TypedVarInfo. +function _link!(vi::TypedVarInfo, vns::VarNameTuple) + return _link!(vi, group_varnames_by_symbol(vns)) end -function _link!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) - vns = _getvns(vi, spl) - return _link!(vi.metadata, vi, vns, spaceval) + +function _link!(vi::TypedVarInfo, vns::NamedTuple) + return _link!(vi.metadata, vi, vns) +end + +""" + filter_subsumed(filter_vns, filtered_vns) + +Return the subset of `filtered_vns` that are subsumed by any variable in `filter_vns`. +""" +function filter_subsumed(filter_vns, filtered_vns) + return filter(x -> any(subsumes(y, x) for y in filter_vns), filtered_vns) end + @generated function _link!( - metadata::NamedTuple{names}, vi, vns, ::Val{space} -) where {names,space} + ::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names} +) where {metadata_names,vns_names} expr = Expr(:block) - for f in names - if inspace(f, space) || length(space) == 0 - push!( - expr.args, - quote - f_vns = vi.metadata.$f.vns - if ~istrans(vi, f_vns[1]) + for f in metadata_names + if !(f in vns_names) + continue + end + push!( + expr.args, + quote + f_vns = vi.metadata.$f.vns + f_vns = filter_subsumed(vns.$f, f_vns) + if !isempty(f_vns) + if !istrans(vi, f_vns[1]) # Iterate over all `f_vns` and transform for vn in f_vns f = internal_to_linked_internal_transform(vi, vn) @@ -1210,45 +1267,65 @@ end else @warn("[DynamicPPL] attempt to link a linked vi") end - end, - ) - end + end + end, + ) end return expr end +function invlink!!(::DynamicTransformation, vi::TypedVarInfo, model::Model) + vns = all_varnames_grouped_by_symbol(vi) + # If we're working with a `VarNamedVector`, we always use immutable. + has_varnamedvector(vi) && return _invlink(model, vi, vns) + # Call `_invlink!` instead of `invlink!` to avoid deprecation warning. + _invlink!(vi, vns) + return vi +end + +function invlink!!(::DynamicTransformation, vi::VarInfo, model::Model) + vns = keys(vi) + # If we're working with a `VarNamedVector`, we always use immutable. + has_varnamedvector(vi) && return _invlink(model, vi, vns) + _invlink!(vi, vns) + return vi +end + +function invlink!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, model::Model) + # By default this will simply evaluate the model with `DynamicTransformationContext`, + # and so we need to specialize to avoid this. + return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(t, vi.varinfo, model) +end + # R -> X for all variables associated with given sampler -function invlink!!( - t::DynamicTransformation, vi::VarInfo, spl::AbstractSampler, model::Model -) +function invlink!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. - has_varnamedvector(vi) && return invlink(t, vi, spl, model) + has_varnamedvector(vi) && return _invlink(model, vi, vns) # Call `_invlink!` instead of `invlink!` to avoid deprecation warning. - _invlink!(vi, spl) + _invlink!(vi, vns) return vi end function invlink!!( ::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, - spl::AbstractSampler, + vns::VarNameTuple, model::Model, ) # By default this will simply evaluate the model with `DynamicTransformationContext`, and so # we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, spl, model) + return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, vns, model) end -function maybe_invlink_before_eval!!(vi::VarInfo, context::AbstractContext, model::Model) +function maybe_invlink_before_eval!!(vi::VarInfo, model::Model) # Because `VarInfo` does not contain any information about what the transformation # other than whether or not it has actually been transformed, the best we can do # is just assume that `default_transformation` is the correct one if `istrans(vi)`. t = istrans(vi) ? default_transformation(model, vi) : NoTransformation() - return maybe_invlink_before_eval!!(t, vi, context, model) + return maybe_invlink_before_eval!!(t, vi, model) end -function _invlink!(vi::UntypedVarInfo, spl::AbstractSampler) - vns = _getvns(vi, spl) +function _invlink!(vi::UntypedVarInfo, vns) if istrans(vi, vns[1]) for vn in vns f = linked_internal_to_internal_transform(vi, vn) @@ -1259,36 +1336,43 @@ function _invlink!(vi::UntypedVarInfo, spl::AbstractSampler) @warn("[DynamicPPL] attempt to invlink an invlinked vi") end end -function _invlink!(vi::TypedVarInfo, spl::AbstractSampler) - return _invlink!(vi, spl, Val(getspace(spl))) + +# If we try to _invlink! a TypedVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the TypedVarInfo. +function _invlink!(vi::TypedVarInfo, vns::VarNameTuple) + return _invlink!(vi.metadata, vi, group_varnames_by_symbol(vns)) end -function _invlink!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) - vns = _getvns(vi, spl) - return _invlink!(vi.metadata, vi, vns, spaceval) + +function _invlink!(vi::TypedVarInfo, vns::NamedTuple) + return _invlink!(vi.metadata, vi, vns) end + @generated function _invlink!( - metadata::NamedTuple{names}, vi, vns, ::Val{space} -) where {names,space} + ::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names} +) where {metadata_names,vns_names} expr = Expr(:block) - for f in names - if inspace(f, space) || length(space) == 0 - push!( - expr.args, - quote - f_vns = vi.metadata.$f.vns - if istrans(vi, f_vns[1]) - # Iterate over all `f_vns` and transform - for vn in f_vns - f = linked_internal_to_internal_transform(vi, vn) - _inner_transform!(vi, vn, f) - settrans!!(vi, false, vn) - end - else - @warn("[DynamicPPL] attempt to invlink an invlinked vi") - end - end, - ) + for f in metadata_names + if !(f in vns_names) + continue end + + push!( + expr.args, + quote + f_vns = vi.metadata.$f.vns + f_vns = filter_subsumed(vns.$f, f_vns) + if istrans(vi, f_vns[1]) + # Iterate over all `f_vns` and transform + for vn in f_vns + f = linked_internal_to_internal_transform(vi, vn) + _inner_transform!(vi, vn, f) + settrans!!(vi, false, vn) + end + else + @warn("[DynamicPPL] attempt to invlink an invlinked vi") + end + end, + ) end return expr end @@ -1320,59 +1404,72 @@ function _getvns_link(varinfo::TypedVarInfo, spl::SampleFromPrior) return map(Returns(nothing), varinfo.metadata) end -function link(::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model) - return _link(model, varinfo, spl) +function link(::DynamicTransformation, vi::TypedVarInfo, model::Model) + return _link(model, vi, all_varnames_grouped_by_symbol(vi)) +end + +function link(::DynamicTransformation, varinfo::VarInfo, model::Model) + return _link(model, varinfo, keys(varinfo)) +end + +function link(::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, model::Model) + # By default this will simply evaluate the model with `DynamicTransformationContext`, and so + # we need to specialize to avoid this. + return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, model) +end + +function link(::DynamicTransformation, varinfo::VarInfo, vns::VarNameTuple, model::Model) + return _link(model, varinfo, vns) end function link( ::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, - spl::AbstractSampler, + vns::VarNameTuple, model::Model, ) - # By default this will simply evaluate the model with `DynamicTransformationContext`, and so - # we need to specialize to avoid this. - return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, spl, model) + # By default this will simply evaluate the model with `DynamicTransformationContext`, + # and so we need to specialize to avoid this. + return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, vns, model) end -function _link( - model::Model, varinfo::Union{UntypedVarInfo,VectorVarInfo}, spl::AbstractSampler -) +function _link(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) - return VarInfo( - _link_metadata!!(model, varinfo, varinfo.metadata, _getvns_link(varinfo, spl)), - Base.Ref(getlogp(varinfo)), - Ref(get_num_produce(varinfo)), - ) + md = _link_metadata!!(model, varinfo, varinfo.metadata, vns) + return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) +end + +# If we try to _link a TypedVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the TypedVarInfo. +function _link(model::Model, varinfo::TypedVarInfo, vns::VarNameTuple) + return _link(model, varinfo, group_varnames_by_symbol(vns)) end -function _link(model::Model, varinfo::TypedVarInfo, spl::AbstractSampler) +function _link(model::Model, varinfo::TypedVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) - md = _link_metadata_namedtuple!( - model, varinfo, varinfo.metadata, _getvns_link(varinfo, spl), Val(getspace(spl)) - ) + md = _link_metadata!(model, varinfo, varinfo.metadata, vns) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end -@generated function _link_metadata_namedtuple!( +@generated function _link_metadata!( model::Model, varinfo::VarInfo, - metadata::NamedTuple{names}, - vns::NamedTuple, - ::Val{space}, -) where {names,space} + metadata::NamedTuple{metadata_names}, + vns::NamedTuple{vns_names}, +) where {metadata_names,vns_names} vals = Expr(:tuple) - for f in names - if inspace(f, space) || length(space) == 0 + for f in metadata_names + if f in vns_names push!(vals.args, :(_link_metadata!!(model, varinfo, metadata.$f, vns.$f))) else push!(vals.args, :(metadata.$f)) end end - return :(NamedTuple{$names}($vals)) + return :(NamedTuple{$metadata_names}($vals)) end -function _link_metadata!!(model::Model, varinfo::VarInfo, metadata::Metadata, target_vns) + +function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns # Construct the new transformed values, and keep track of their lengths. @@ -1444,57 +1541,76 @@ function _link_metadata!!( return metadata end +function invlink(::DynamicTransformation, vi::TypedVarInfo, model::Model) + return _invlink(model, vi, all_varnames_grouped_by_symbol(vi)) +end + +function invlink(::DynamicTransformation, vi::VarInfo, model::Model) + return _invlink(model, vi, keys(vi)) +end + function invlink( - ::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model + ::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, model::Model ) - return _invlink(model, varinfo, spl) + # By default this will simply evaluate the model with `DynamicTransformationContext`, and so + # we need to specialize to avoid this. + return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, model) +end + +function invlink(::DynamicTransformation, varinfo::VarInfo, vns::VarNameTuple, model::Model) + return _invlink(model, varinfo, vns) end + function invlink( ::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, - spl::AbstractSampler, + vns::VarNameTuple, model::Model, ) # By default this will simply evaluate the model with `DynamicTransformationContext`, and so # we need to specialize to avoid this. - return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, spl, model) + return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, vns, model) end -function _invlink(model::Model, varinfo::VarInfo, spl::AbstractSampler) +function _invlink(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) return VarInfo( - _invlink_metadata!!(model, varinfo, varinfo.metadata, _getvns_link(varinfo, spl)), + _invlink_metadata!!(model, varinfo, varinfo.metadata, vns), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) end -function _invlink(model::Model, varinfo::TypedVarInfo, spl::AbstractSampler) +# If we try to _invlink a TypedVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the TypedVarInfo. +function _invlink(model::Model, varinfo::TypedVarInfo, vns::VarNameTuple) + return _invlink(model, varinfo, group_varnames_by_symbol(vns)) +end + +function _invlink(model::Model, varinfo::TypedVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) - md = _invlink_metadata_namedtuple!( - model, varinfo, varinfo.metadata, _getvns_link(varinfo, spl), Val(getspace(spl)) - ) + md = _invlink_metadata!(model, varinfo, varinfo.metadata, vns) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end -@generated function _invlink_metadata_namedtuple!( +@generated function _invlink_metadata!( model::Model, varinfo::VarInfo, - metadata::NamedTuple{names}, - vns::NamedTuple, - ::Val{space}, -) where {names,space} + metadata::NamedTuple{metadata_names}, + vns::NamedTuple{vns_names}, +) where {metadata_names,vns_names} vals = Expr(:tuple) - for f in names - if inspace(f, space) || length(space) == 0 + for f in metadata_names + if (f in vns_names) push!(vals.args, :(_invlink_metadata!!(model, varinfo, metadata.$f, vns.$f))) else push!(vals.args, :(metadata.$f)) end end - return :(NamedTuple{$names}($vals)) + return :(NamedTuple{$metadata_names}($vals)) end + function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns @@ -1545,7 +1661,7 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ end function _invlink_metadata!!( - model::Model, varinfo::VarInfo, metadata::VarNamedVector, target_vns + ::Model, varinfo::VarInfo, metadata::VarNamedVector, target_vns ) vns = target_vns === nothing ? keys(metadata) : target_vns for vn in vns @@ -1966,37 +2082,35 @@ function unset_flag!(vnv::VarNamedVector, ::VarName, flag::String, ignorable::Bo end """ - set_retained_vns_del_by_spl!(vi::VarInfo, spl::Sampler) + set_retained_vns_del!(vi::VarInfo) Set the `"del"` flag of variables in `vi` with `order > vi.num_produce[]` to `true`. """ -function set_retained_vns_del_by_spl!(vi::UntypedVarInfo, spl::Sampler) - # Get the indices of `vns` that belong to `spl` as a vector - gidcs = _getidcs(vi, spl) +function set_retained_vns_del!(vi::UntypedVarInfo) + idcs = _getidcs(vi) if get_num_produce(vi) == 0 - for i in length(gidcs):-1:1 - vi.metadata.flags["del"][gidcs[i]] = true + for i in length(idcs):-1:1 + vi.metadata.flags["del"][idcs[i]] = true end else for i in 1:length(vi.orders) - if i in gidcs && vi.orders[i] > get_num_produce(vi) + if i in idcs && vi.orders[i] > get_num_produce(vi) vi.metadata.flags["del"][i] = true end end end return nothing end -function set_retained_vns_del_by_spl!(vi::TypedVarInfo, spl::Sampler) - # Get the indices of `vns` that belong to `spl` as a NamedTuple, one entry for each symbol - gidcs = _getidcs(vi, spl) - return _set_retained_vns_del_by_spl!(vi.metadata, gidcs, get_num_produce(vi)) +function set_retained_vns_del!(vi::TypedVarInfo) + idcs = _getidcs(vi) + return _set_retained_vns_del!(vi.metadata, idcs, get_num_produce(vi)) end -@generated function _set_retained_vns_del_by_spl!( - metadata, gidcs::NamedTuple{names}, num_produce +@generated function _set_retained_vns_del!( + metadata, idcs::NamedTuple{names}, num_produce ) where {names} expr = Expr(:block) for f in names - f_gidcs = :(gidcs.$f) + f_idcs = :(idcs.$f) f_orders = :(metadata.$f.orders) f_flags = :(metadata.$f.flags) push!( @@ -2004,12 +2118,12 @@ end quote # Set the flag for variables with symbol `f` if num_produce == 0 - for i in length($f_gidcs):-1:1 - $f_flags["del"][$f_gidcs[i]] = true + for i in length($f_idcs):-1:1 + $f_flags["del"][$f_idcs[i]] = true end else for i in 1:length($f_orders) - if i in $f_gidcs && $f_orders[i] > num_produce + if i in $f_idcs && $f_orders[i] > num_produce $f_flags["del"][i] = true end end diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index 565e82480..7da126321 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -1068,7 +1068,12 @@ function unflatten(vnv::VarNamedVector, vals::AbstractVector) new_ranges = deepcopy(vnv.ranges) recontiguify_ranges!(new_ranges) return VarNamedVector( - vnv.varname_to_index, vnv.varnames, new_ranges, vals, vnv.transforms + vnv.varname_to_index, + vnv.varnames, + new_ranges, + vals, + vnv.transforms, + vnv.is_unconstrained, ) end diff --git a/test/model.jl b/test/model.jl index 118f60a40..e91de4bd2 100644 --- a/test/model.jl +++ b/test/model.jl @@ -226,7 +226,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() model = DynamicPPL.TestUtils.demo_dynamic_constraint() spl = SampleFromPrior() vi = VarInfo(model, spl, DefaultContext(), DynamicPPL.Metadata()) - link!!(vi, spl, model) + vi = link!!(vi, model) for i in 1:10 # Sample with large variations. diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 4343563eb..137c791c2 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -275,9 +275,7 @@ # Make sure `maybe_invlink_before_eval!!` results in `invlink!!`. @test !DynamicPPL.istrans( - DynamicPPL.maybe_invlink_before_eval!!( - deepcopy(vi), SamplingContext(), model - ), + DynamicPPL.maybe_invlink_before_eval!!(deepcopy(vi), model) ) # Resulting varinfo should no longer be transformed. diff --git a/test/utils.jl b/test/utils.jl index 3f435dca4..d683f132d 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -48,4 +48,27 @@ x = rand(dist) @test DynamicPPL.tovec(x) == vec(x.UL) end + + @testset "unique_syms" begin + vns = (@varname(x), @varname(y[1]), @varname(x.a), @varname(z[15]), @varname(y[2])) + @inferred DynamicPPL.unique_syms(vns) + @inferred DynamicPPL.unique_syms(()) + @test DynamicPPL.unique_syms(vns) == (:x, :y, :z) + @test DynamicPPL.unique_syms(()) == () + end + + @testset "group_varnames_by_symbol" begin + vns_tuple = ( + @varname(x), @varname(y[1]), @varname(x.a), @varname(z[15]), @varname(y[2]) + ) + vns_vec = collect(vns_tuple) + vns_nt = (; + x=[@varname(x), @varname(x.a)], + y=[@varname(y[1]), @varname(y[2])], + z=[@varname(z[15])], + ) + vns_vec_single_symbol = [@varname(x.a), @varname(x.b), @varname(x[1])] + @inferred DynamicPPL.group_varnames_by_symbol(vns_tuple) + @test DynamicPPL.group_varnames_by_symbol(vns_tuple) == vns_nt + end end diff --git a/test/varinfo.jl b/test/varinfo.jl index fce87b2f3..d689a1bf4 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -1,8 +1,3 @@ -# Dummy algorithm for testing -# Invoke with: DynamicPPL.Sampler(MyAlg{(:x, :y)}(), ...) -struct MyAlg{space} end -DynamicPPL.getspace(::DynamicPPL.Sampler{MyAlg{space}}) where {space} = space - function check_varinfo_keys(varinfo, vns) if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} # NOTE: We can't compare the `keys(varinfo_merged)` directly with `vns`, @@ -19,16 +14,13 @@ function check_varinfo_keys(varinfo, vns) end end -function randr( - vi::DynamicPPL.VarInfo, - vn::VarName, - dist::Distribution, - spl::DynamicPPL.Sampler, - count::Bool=false, -) +""" +Return the value of `vn` in `vi`. If one doesn't exist, sample and set it. +""" +function randr(vi::DynamicPPL.VarInfo, vn::VarName, dist::Distribution) if !haskey(vi, vn) r = rand(dist) - push!!(vi, vn, r, dist, spl) + push!!(vi, vn, r, dist) r elseif DynamicPPL.is_flagged(vi, vn, "del") DynamicPPL.unset_flag!(vi, vn, "del") @@ -37,8 +29,6 @@ function randr( DynamicPPL.setorder!(vi, vn, DynamicPPL.get_num_produce(vi)) r else - count && checkindex(vn, vi, spl) - DynamicPPL.updategid!(vi, vn, spl) vi[vn] end end @@ -66,7 +56,6 @@ end tind = fmeta.idcs[vn] @test meta.dists[ind] == fmeta.dists[tind] @test meta.orders[ind] == fmeta.orders[tind] - @test meta.gids[ind] == fmeta.gids[tind] for flag in keys(meta.flags) @test meta.flags[flag][ind] == fmeta.flags[flag][tind] end @@ -89,22 +78,6 @@ end vn2 = @varname x[1][2] @test vn2 == vn1 @test hash(vn2) == hash(vn1) - @test inspace(vn1, (:x,)) - - # Tests for `inspace` - space = (:x, :y, @varname(z[1]), @varname(M[1:10, :])) - - @test inspace(@varname(x), space) - @test inspace(@varname(y), space) - @test inspace(@varname(x[1]), space) - @test inspace(@varname(z[1][1]), space) - @test inspace(@varname(z[1][:]), space) - @test inspace(@varname(z[1][2:3:10]), space) - @test inspace(@varname(M[[2, 3], 1]), space) - @test_throws ErrorException inspace(@varname(M[:, 1:4]), space) - @test inspace(@varname(M[1, [2, 4, 6]]), space) - @test !inspace(@varname(z[2]), space) - @test !inspace(@varname(z), space) function test_base!!(vi_original) vi = empty!!(vi_original) @@ -114,38 +87,31 @@ end vn = @varname x dist = Normal(0, 1) r = rand(dist) - gid = DynamicPPL.Selector() @test isempty(vi) @test ~haskey(vi, vn) @test !(vn in keys(vi)) - vi = push!!(vi, vn, r, dist, gid) + vi = push!!(vi, vn, r, dist) @test ~isempty(vi) @test haskey(vi, vn) @test vn in keys(vi) @test length(vi[vn]) == 1 - @test length(vi[SampleFromPrior()]) == 1 - @test vi[vn] == r - @test vi[SampleFromPrior()][1] == r vi = DynamicPPL.setindex!!(vi, 2 * r, vn) @test vi[vn] == 2 * r - @test vi[SampleFromPrior()][1] == 2 * r - vi = DynamicPPL.setindex!!(vi, [3 * r], SampleFromPrior()) - @test vi[vn] == 3 * r - @test vi[SampleFromPrior()][1] == 3 * r # TODO(mhauru) Implement these functions for other VarInfo types too. if vi isa DynamicPPL.VectorVarInfo delete!(vi, vn) @test isempty(vi) - vi = push!!(vi, vn, r, dist, gid) + vi = push!!(vi, vn, r, dist) end vi = empty!!(vi) @test isempty(vi) - return push!!(vi, vn, r, dist, gid) + vi = push!!(vi, vn, r, dist) + @test ~isempty(vi) end vi = VarInfo() @@ -182,9 +148,8 @@ end vn_x = @varname x dist = Normal(0, 1) r = rand(dist) - gid = Selector() - push!!(vi, vn_x, r, dist, gid) + push!!(vi, vn_x, r, dist) # del is set by default @test !is_flagged(vi, vn_x, "del") @@ -204,35 +169,13 @@ end vn_x = @varname x vn_y = @varname y untyped_vi = VarInfo() - untyped_vi = push!!(untyped_vi, vn_x, 1.0, Normal(0, 1), Selector()) + untyped_vi = push!!(untyped_vi, vn_x, 1.0, Normal(0, 1)) typed_vi = TypedVarInfo(untyped_vi) - typed_vi = push!!(typed_vi, vn_y, 2.0, Normal(0, 1), Selector()) + typed_vi = push!!(typed_vi, vn_y, 2.0, Normal(0, 1)) @test typed_vi[vn_x] == 1.0 @test typed_vi[vn_y] == 2.0 end - @testset "setgid!" begin - vi = VarInfo(DynamicPPL.Metadata()) - meta = vi.metadata - vn = @varname x - dist = Normal(0, 1) - r = rand(dist) - gid1 = Selector() - gid2 = Selector(2, :HMC) - - push!!(vi, vn, r, dist, gid1) - @test meta.gids[meta.idcs[vn]] == Set([gid1]) - setgid!(vi, gid2, vn) - @test meta.gids[meta.idcs[vn]] == Set([gid1, gid2]) - - vi = empty!!(TypedVarInfo(vi)) - meta = vi.metadata - push!!(vi, vn, r, dist, gid1) - @test meta.x.gids[meta.x.idcs[vn]] == Set([gid1]) - setgid!(vi, gid2, vn) - @test meta.x.gids[meta.x.idcs[vn]] == Set([gid1, gid2]) - end - @testset "setval! & setval_and_resample!" begin @model function testmodel(x) n = length(x) @@ -397,10 +340,9 @@ end """ function test_setval!(model, chain; sample_idx=1, chain_idx=1) var_info = VarInfo(model) - spl = SampleFromPrior() - θ_old = var_info[spl] + θ_old = var_info[:] DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) - θ_new = var_info[spl] + θ_new = var_info[:] @test θ_old != θ_new vals = DynamicPPL.values_as(var_info, OrderedDict) iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) @@ -432,13 +374,21 @@ end end @testset "link!! and invlink!!" begin - @model gdemo(x, y) = begin + @model gdemo(a, b, ::Type{T}=Float64) where {T} = begin s ~ InverseGamma(2, 3) m ~ Uniform(0, 2) - x ~ Normal(m, sqrt(s)) - y ~ Normal(m, sqrt(s)) + x = Vector{T}(undef, length(a)) + x .~ Normal(m, sqrt(s)) + y = Vector{T}(undef, length(a)) + for i in eachindex(y) + y[i] ~ Normal(m, sqrt(s)) + end + a .~ Normal(m, sqrt(s)) + for i in eachindex(b) + b[i] ~ Normal(x[i] * y[i], sqrt(s)) + end end - model = gdemo(1.0, 2.0) + model = gdemo([1.0, 1.5], [2.0, 2.5]) # Check that instantiating the model does not perform linking vi = VarInfo() @@ -448,38 +398,55 @@ end # Check that linking and invlinking set the `trans` flag accordingly v = copy(meta.vals) - link!!(vi, model) + vi = link!!(vi, model) @test all(x -> istrans(vi, x), meta.vns) - invlink!!(vi, model) + vi = invlink!!(vi, model) @test all(x -> !istrans(vi, x), meta.vns) @test meta.vals ≈ v atol = 1e-10 # Check that linking and invlinking preserves the values vi = TypedVarInfo(vi) meta = vi.metadata - @test all(x -> !istrans(vi, x), meta.s.vns) - @test all(x -> !istrans(vi, x), meta.m.vns) v_s = copy(meta.s.vals) v_m = copy(meta.m.vals) - link!!(vi, model) - @test all(x -> istrans(vi, x), meta.s.vns) - @test all(x -> istrans(vi, x), meta.m.vns) - invlink!!(vi, model) + v_x = copy(meta.x.vals) + v_y = copy(meta.y.vals) + @test all(x -> !istrans(vi, x), meta.s.vns) @test all(x -> !istrans(vi, x), meta.m.vns) - @test meta.s.vals ≈ v_s atol = 1e-10 - @test meta.m.vals ≈ v_m atol = 1e-10 - - # Transform only one variable (`s`) but not the others (`m`) - spl = DynamicPPL.Sampler(MyAlg{(:s,)}(), model) - link!!(vi, spl, model) + vi = link!!(vi, model) @test all(x -> istrans(vi, x), meta.s.vns) - @test all(x -> !istrans(vi, x), meta.m.vns) - invlink!!(vi, spl, model) + @test all(x -> istrans(vi, x), meta.m.vns) + vi = invlink!!(vi, model) @test all(x -> !istrans(vi, x), meta.s.vns) @test all(x -> !istrans(vi, x), meta.m.vns) @test meta.s.vals ≈ v_s atol = 1e-10 @test meta.m.vals ≈ v_m atol = 1e-10 + + # Transform only one variable + all_vns = vcat(meta.s.vns, meta.m.vns, meta.x.vns, meta.y.vns) + for vn in [ + @varname(s), + @varname(m), + @varname(x), + @varname(y), + @varname(x[2]), + @varname(y[2]) + ] + target_vns = filter(x -> subsumes(vn, x), all_vns) + other_vns = filter(x -> !subsumes(vn, x), all_vns) + @test !isempty(target_vns) + @test !isempty(other_vns) + vi = link!!(vi, (vn,), model) + @test all(x -> istrans(vi, x), target_vns) + @test all(x -> !istrans(vi, x), other_vns) + vi = invlink!!(vi, (vn,), model) + @test all(x -> !istrans(vi, x), all_vns) + @test meta.s.vals ≈ v_s atol = 1e-10 + @test meta.m.vals ≈ v_m atol = 1e-10 + @test meta.x.vals ≈ v_x atol = 1e-10 + @test meta.y.vals ≈ v_y atol = 1e-10 + end end @testset "istrans" begin @@ -856,73 +823,17 @@ end @test varinfo_merged[@varname(x)] == varinfo_right[@varname(x)] @test DynamicPPL.istrans(varinfo_merged, @varname(x)) end - - # The below used to error, testing to avoid regression. - @testset "merge gids" begin - gidset_left = Set([Selector(1)]) - vi_left = VarInfo() - vi_left = push!!(vi_left, @varname(x), 1.0, Normal(), gidset_left) - gidset_right = Set([Selector(2)]) - vi_right = VarInfo() - vi_right = push!!(vi_right, @varname(y), 2.0, Normal(), gidset_right) - varinfo_merged = merge(vi_left, vi_right) - @test DynamicPPL.getgid(varinfo_merged, @varname(x)) == gidset_left - @test DynamicPPL.getgid(varinfo_merged, @varname(y)) == gidset_right - end - - # The below used to error, testing to avoid regression. - @testset "merge different dimensions" begin - vn = @varname(x) - vi_single = VarInfo() - vi_single = push!!(vi_single, vn, 1.0, Normal()) - vi_double = VarInfo() - vi_double = push!!(vi_double, vn, [0.5, 0.6], Dirichlet(2, 1.0)) - @test merge(vi_single, vi_double)[vn] == [0.5, 0.6] - @test merge(vi_double, vi_single)[vn] == 1.0 - end end - @testset "VarInfo with selectors" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - varinfo = VarInfo( - model, - DynamicPPL.SampleFromPrior(), - DynamicPPL.DefaultContext(), - DynamicPPL.Metadata(), - ) - selector = DynamicPPL.Selector() - spl = Sampler(MyAlg{(:s,)}(), model, selector) - - vns = DynamicPPL.TestUtils.varnames(model) - vns_s = filter(vn -> DynamicPPL.getsym(vn) === :s, vns) - vns_m = filter(vn -> DynamicPPL.getsym(vn) === :m, vns) - for vn in vns_s - DynamicPPL.updategid!(varinfo, vn, spl) - end - - # Should only get the variables subsumed by `@varname(s)`. - @test varinfo[spl] == - mapreduce(Base.Fix1(DynamicPPL.getindex_internal, varinfo), vcat, vns_s) - - # `link` - varinfo_linked = DynamicPPL.link(varinfo, spl, model) - # `s` variables should be linked - @test any(Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_s) - # `m` variables should NOT be linked - @test any(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m) - # And `varinfo` should be unchanged - @test all(!Base.Fix1(DynamicPPL.istrans, varinfo), vns) - - # `invlink` - varinfo_invlinked = DynamicPPL.invlink(varinfo_linked, spl, model) - # `s` variables should no longer be linked - @test all(!Base.Fix1(DynamicPPL.istrans, varinfo_invlinked), vns_s) - # `m` variables should still not be linked - @test all(!Base.Fix1(DynamicPPL.istrans, varinfo_invlinked), vns_m) - # And `varinfo_linked` should be unchanged - @test any(Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_s) - @test any(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m) - end + # The below used to error, testing to avoid regression. + @testset "merge different dimensions" begin + vn = @varname(x) + vi_single = VarInfo() + vi_single = push!!(vi_single, vn, 1.0, Normal()) + vi_double = VarInfo() + vi_double = push!!(vi_double, vn, [0.5, 0.6], Dirichlet(2, 1.0)) + @test merge(vi_single, vi_double)[vn] == [0.5, 0.6] + @test merge(vi_double, vi_single)[vn] == 1.0 end @testset "sampling from linked varinfo" begin @@ -1025,25 +936,22 @@ end vi = DynamicPPL.VarInfo() dists = [Categorical([0.7, 0.3]), Normal()] - spl1 = DynamicPPL.Sampler(MyAlg{()}(), empty_model()) - spl2 = DynamicPPL.Sampler(MyAlg{()}(), empty_model()) - # First iteration, variables are added to vi # variables samples in order: z1,a1,z2,a2,z3 DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z1, dists[1], spl1) - randr(vi, vn_a1, dists[2], spl1) + randr(vi, vn_z1, dists[1]) + randr(vi, vn_a1, dists[2]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_b, dists[2], spl2) - randr(vi, vn_z2, dists[1], spl1) - randr(vi, vn_a2, dists[2], spl1) + randr(vi, vn_b, dists[2]) + randr(vi, vn_z2, dists[1]) + randr(vi, vn_a2, dists[2]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z3, dists[1], spl1) + randr(vi, vn_z3, dists[1]) @test vi.metadata.orders == [1, 1, 2, 2, 2, 3] @test DynamicPPL.get_num_produce(vi) == 3 DynamicPPL.reset_num_produce!(vi) - DynamicPPL.set_retained_vns_del_by_spl!(vi, spl1) + DynamicPPL.set_retained_vns_del!(vi) @test DynamicPPL.is_flagged(vi, vn_z1, "del") @test DynamicPPL.is_flagged(vi, vn_a1, "del") @test DynamicPPL.is_flagged(vi, vn_z2, "del") @@ -1051,13 +959,13 @@ end @test DynamicPPL.is_flagged(vi, vn_z3, "del") DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z1, dists[1], spl1) - randr(vi, vn_a1, dists[2], spl1) + randr(vi, vn_z1, dists[1]) + randr(vi, vn_a1, dists[2]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z2, dists[1], spl1) + randr(vi, vn_z2, dists[1]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z3, dists[1], spl1) - randr(vi, vn_a2, dists[2], spl1) + randr(vi, vn_z3, dists[1]) + randr(vi, vn_a2, dists[2]) @test vi.metadata.orders == [1, 1, 2, 2, 3, 3] @test DynamicPPL.get_num_produce(vi) == 3 @@ -1065,21 +973,21 @@ end # First iteration, variables are added to vi # variables samples in order: z1,a1,z2,a2,z3 DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z1, dists[1], spl1) - randr(vi, vn_a1, dists[2], spl1) + randr(vi, vn_z1, dists[1]) + randr(vi, vn_a1, dists[2]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_b, dists[2], spl2) - randr(vi, vn_z2, dists[1], spl1) - randr(vi, vn_a2, dists[2], spl1) + randr(vi, vn_b, dists[2]) + randr(vi, vn_z2, dists[1]) + randr(vi, vn_a2, dists[2]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z3, dists[1], spl1) + randr(vi, vn_z3, dists[1]) @test vi.metadata.z.orders == [1, 2, 3] @test vi.metadata.a.orders == [1, 2] @test vi.metadata.b.orders == [2] @test DynamicPPL.get_num_produce(vi) == 3 DynamicPPL.reset_num_produce!(vi) - DynamicPPL.set_retained_vns_del_by_spl!(vi, spl1) + DynamicPPL.set_retained_vns_del!(vi) @test DynamicPPL.is_flagged(vi, vn_z1, "del") @test DynamicPPL.is_flagged(vi, vn_a1, "del") @test DynamicPPL.is_flagged(vi, vn_z2, "del") @@ -1087,69 +995,16 @@ end @test DynamicPPL.is_flagged(vi, vn_z3, "del") DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z1, dists[1], spl1) - randr(vi, vn_a1, dists[2], spl1) + randr(vi, vn_z1, dists[1]) + randr(vi, vn_a1, dists[2]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z2, dists[1], spl1) + randr(vi, vn_z2, dists[1]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z3, dists[1], spl1) - randr(vi, vn_a2, dists[2], spl1) + randr(vi, vn_z3, dists[1]) + randr(vi, vn_a2, dists[2]) @test vi.metadata.z.orders == [1, 2, 3] @test vi.metadata.a.orders == [1, 3] @test vi.metadata.b.orders == [2] @test DynamicPPL.get_num_produce(vi) == 3 end - - @testset "varinfo ranges" begin - @model empty_model() = x = 1 - dists = [Normal(0, 1), MvNormal(zeros(2), I), Wishart(7, [1 0.5; 0.5 1])] - - function test_varinfo!(vi) - spl2 = DynamicPPL.Sampler(MyAlg{(:w, :u)}(), empty_model()) - vn_w = @varname w - randr(vi, vn_w, dists[1], spl2, true) - - vn_x = @varname x - vn_y = @varname y - vn_z = @varname z - vns = [vn_x, vn_y, vn_z] - - spl1 = DynamicPPL.Sampler(MyAlg{(:x, :y, :z)}(), empty_model()) - for i in 1:3 - r = randr(vi, vns[i], dists[i], spl1, false) - val = vi[vns[i]] - @test sum(val - r) <= 1e-9 - end - - idcs = DynamicPPL._getidcs(vi, spl1) - if idcs isa NamedTuple - @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 3 - else - @test length(idcs) == 3 - end - @test length(vi[spl1]) == 7 - - idcs = DynamicPPL._getidcs(vi, spl2) - if idcs isa NamedTuple - @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 1 - else - @test length(idcs) == 1 - end - @test length(vi[spl2]) == 1 - - vn_u = @varname u - randr(vi, vn_u, dists[1], spl2, true) - - idcs = DynamicPPL._getidcs(vi, spl2) - if idcs isa NamedTuple - @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 2 - else - @test length(idcs) == 2 - end - @test length(vi[spl2]) == 2 - end - vi = DynamicPPL.VarInfo() - test_varinfo!(vi) - test_varinfo!(empty!!(DynamicPPL.TypedVarInfo(vi))) - end end