diff --git a/Project.toml b/Project.toml index 4ed8a0ef1..f995d7359 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.29.2" +version = "0.30" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -46,7 +46,7 @@ AbstractMCMC = "5" AbstractPPL = "0.8.4, 0.9" Accessors = "0.1" BangBang = "0.4.1" -Bijectors = "0.13.9" +Bijectors = "0.13.18" ChainRulesCore = "1" Compat = "4" ConstructionBase = "1.5.4" diff --git a/docs/make.jl b/docs/make.jl index 73ee33631..9f170fdf7 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -22,7 +22,7 @@ makedocs(; pages=[ "Home" => "index.md", "API" => "api.md", - "Internals" => ["internals/transformations.md"], + "Internals" => ["internals/varinfo.md", "internals/transformations.md"], ], checkdocs=:exports, doctest=false, diff --git a/docs/src/api.md b/docs/src/api.md index 156b51e03..638f6f3ee 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -294,10 +294,19 @@ resetlogp!! ```@docs keys getindex -DynamicPPL.getindex_internal push!! empty!! isempty +DynamicPPL.getindex_internal +DynamicPPL.setindex_internal! +DynamicPPL.update_internal! +DynamicPPL.insert_internal! +DynamicPPL.length_internal +DynamicPPL.reset! +DynamicPPL.update! +DynamicPPL.insert! +DynamicPPL.loosen_types!! +DynamicPPL.tighten_types ``` ```@docs diff --git a/docs/src/internals/varinfo.md b/docs/src/internals/varinfo.md new file mode 100644 index 000000000..4f0480b61 --- /dev/null +++ b/docs/src/internals/varinfo.md @@ -0,0 +1,302 @@ +# Design of `VarInfo` + +[`VarInfo`](@ref) is a fairly simple structure. + +```@docs; canonical=false +VarInfo +``` + +It contains + + - a `logp` field for accumulation of the log-density evaluation, and + - a `metadata` field for storing information about the realizations of the different variables. + +Representing `logp` is fairly straight-forward: we'll just use a `Real` or an array of `Real`, depending on the context. 
+ +**Representing `metadata` is a bit trickier**. This is supposed to contain all the necessary information for each `VarName` to enable the different executions of the model + extraction of different properties of interest after execution, e.g. the realization / value corresponding to a variable `@varname(x)`. + +!!! note + + We want to work with `VarName` rather than something like `Symbol` or `String` as `VarName` contains additional structural information, e.g. a `Symbol("x[1]")` can be a result of either `var"x[1]" ~ Normal()` or `x[1] ~ Normal()`; these scenarios are disambiguated by `VarName`. + +To ensure that `VarInfo` is simple and intuitive to work with, we want `VarInfo`, and hence the underlying `metadata`, to replicate the following functionality of `Dict`: + + - `keys(::Dict)`: return all the `VarName`s present in `metadata`. + - `haskey(::Dict)`: check if a particular `VarName` is present in `metadata`. + - `getindex(::Dict, ::VarName)`: return the realization corresponding to a particular `VarName`. + - `setindex!(::Dict, val, ::VarName)`: set the realization corresponding to a particular `VarName`. + - `push!(::Dict, ::Pair)`: add a new key-value pair to the container. + - `delete!(::Dict, ::VarName)`: delete the realization corresponding to a particular `VarName`. + - `empty!(::Dict)`: delete all realizations in `metadata`. + - `merge(::Dict, ::Dict)`: merge two `metadata` structures according to similar rules as `Dict`. + +*But* for general-purpose samplers, we often want to work with a simple flattened structure, typically a `Vector{<:Real}`. One can access a vectorised version of a variable's value with the following vector-like functions: + + - `getindex_internal(::VarInfo, ::VarName)`: get the flattened value of a single variable. + - `getindex_internal(::VarInfo, ::Colon)`: get the flattened values of all variables. 
+ - `getindex_internal(::VarInfo, i::Int)`: get the `i`th value of the flattened vector of all values. + - `setindex_internal!(::VarInfo, ::AbstractVector, ::VarName)`: set the flattened value of a variable. + - `setindex_internal!(::VarInfo, val, i::Int)`: set the `i`th value of the flattened vector of all values. + - `length_internal(::VarInfo)`: return the length of the flat representation of `metadata`. + +The functions have `_internal` in their name because internally `VarInfo` always stores values in vectorised form. + +Moreover, a link transformation can be applied to a `VarInfo` with `link!!` (and reversed with `invlink!!`), which applies a reversible transformation to the internal storage format of a variable that makes the range of the random variable cover all of Euclidean space. `getindex_internal` and `setindex_internal!` give direct access to the vectorised value after such a transformation, which is what samplers often need to be able to sample in unconstrained space. One can also manually set a transformation by giving `setindex_internal!` a fourth, optional argument: a function that maps the internally stored value to the actual value of the variable. + +Finally, we want the underlying representation used in `metadata` to have a few performance-related properties: + + 1. Type-stable when possible, but functional when not. + 2. Efficient storage and iteration when possible, but functional when not. + +The "but functional when not" is important as we want to support arbitrary models, which means that we can't always have these performance properties. + +In the following sections, we'll outline how we achieve this in [`VarInfo`](@ref). + +## Type-stability + +Ensuring type-stability is somewhat non-trivial to address since we want this to be the case even when models mix continuous (typically `Float64`) and discrete (typically `Int`) variables. 
+ +Suppose we have an implementation of `metadata` which implements the functionality outlined in the previous section. The way we approach this in `VarInfo` is to use a `NamedTuple` with a separate `metadata` *for each distinct `Symbol` used*. For example, if we have a model of the form + +```@example varinfo-design +using DynamicPPL, Distributions, FillArrays + +@model function demo() + x ~ product_distribution(Fill(Bernoulli(0.5), 2)) + y ~ Normal(0, 1) + return nothing +end +``` + +then we construct a type-stable representation by using a `NamedTuple{(:x, :y), Tuple{Vx, Vy}}` where + + - `Vx` is a container with `eltype` `Bool`, and + - `Vy` is a container with `eltype` `Float64`. + +Since `VarName` contains the `Symbol` used in its type, something like `getindex(varinfo, @varname(x))` can be resolved to `getindex(varinfo.metadata.x, @varname(x))` at compile-time. + +For example, with the model above we have + +```@example varinfo-design +# Type-unstable `VarInfo` +varinfo_untyped = DynamicPPL.untyped_varinfo( + demo(), SampleFromPrior(), DefaultContext(), DynamicPPL.Metadata() +) +typeof(varinfo_untyped.metadata) +``` + +```@example varinfo-design +# Type-stable `VarInfo` +varinfo_typed = DynamicPPL.typed_varinfo(demo()) +typeof(varinfo_typed.metadata) +``` + +They both work as expected but one results in concrete typing and the other does not: + +```@example varinfo-design +varinfo_untyped[@varname(x)], varinfo_untyped[@varname(y)] +``` + +```@example varinfo-design +varinfo_typed[@varname(x)], varinfo_typed[@varname(y)] +``` + +Notice that the untyped `VarInfo` uses `Vector{Real}` to store the boolean entries while the typed uses `Vector{Bool}`. This is because the untyped version needs the underlying container to be able to handle both the `Bool` for `x` and the `Float64` for `y`, while the typed version can use a `Vector{Bool}` for `x` and a `Vector{Float64}` for `y` due to its usage of `NamedTuple`. + +!!! 
warning + + Of course, this `NamedTuple` approach is *not* necessarily going to help us in scenarios where the `Symbol` does not correspond to a unique type, e.g. + + ```julia + x[1] ~ Bernoulli(0.5) + x[2] ~ Normal(0, 1) + ``` + + In this case we'll end up with a `NamedTuple{(:x,), Tuple{Vx}}` where `Vx` is a container with `eltype` `Union{Bool, Float64}` or something worse. This is *not* type-stable but will still be functional. + + In practice, we rarely observe such mixing of types, therefore in DynamicPPL, and more widely in Turing.jl, we use a `NamedTuple` approach for type-stability with great success. + +!!! warning + + Another downside with such a `NamedTuple` approach is that a model with lots of tilde-statements, e.g. `a ~ Normal()`, `b ~ Normal()`, ..., `z ~ Normal()`, will result in a `NamedTuple` with 26 entries, potentially leading to long compilation times. + + For these scenarios it can be useful to fall back to "untyped" representations. + +Hence we obtain a "type-stable when possible"-representation by wrapping it in a `NamedTuple` and partially resolving the `getindex`, `setindex!`, etc. methods at compile-time. When type-stability is *not* desired, we can simply use a single `metadata` for all `VarName`s instead of a `NamedTuple` wrapping a collection of `metadata`s. + +## Efficient storage and iteration + +We achieve efficient storage and iteration through the implementation of the `metadata`. In particular, we do so with [`DynamicPPL.VarNamedVector`](@ref): + +```@docs +DynamicPPL.VarNamedVector +``` + +In a [`DynamicPPL.VarNamedVector{<:VarName,T}`](@ref), we achieve the desiderata by storing the values for different `VarName`s contiguously in a `Vector{T}` and keeping track of which ranges correspond to which `VarName`s. + +This does require a bit of book-keeping, in particular when it comes to insertions and deletions. 
Internally, this is handled by assigning each `VarName` a unique `Int` index in the `varname_to_index` field, which is then used to index into the following fields: + + - `varnames::Vector{<:VarName}`: the `VarName`s in the order they appear in the `Vector{T}`. + - `ranges::Vector{UnitRange{Int}}`: the ranges of indices in the `Vector{T}` that correspond to each `VarName`. + - `transforms::Vector`: the transforms associated with each `VarName`. + +Mutating functions, e.g. `setindex_internal!(vnv::VarNamedVector, val, vn::VarName)`, are then treated according to the following rules: + + 1. If `vn` is not already present: add it to the end of `vnv.varnames`, add the `val` to the underlying `vnv.vals`, etc. + + 2. If `vn` is already present in `vnv`: + + 1. If `val` has the *same length* as the existing value for `vn`: replace the existing value. + 2. If `val` has a *smaller length* than the existing value for `vn`: replace the existing value and mark the remaining indices as "inactive" by increasing the entry in the `vnv.num_inactive` field. + 3. If `val` has a *larger length* than the existing value for `vn`: expand the underlying `vnv.vals` to accommodate the new value, update all `VarName`s occurring after `vn`, and update the `vnv.ranges` to point to the new range for `vn`. + +This means that `VarNamedVector` is allowed to grow as needed, while "shrinking" (i.e. insertion of smaller elements) is handled by simply marking the redundant indices as "inactive". This turns out to be efficient for the use-cases that we are generally interested in. + +For example, we want to optimize code-paths which effectively boil down to the inner loop in the following example: + +```julia +# Construct a `VarInfo` with types inferred from `model`. +varinfo = VarInfo(model) + +# Repeatedly sample from `model`. +for _ in 1:num_samples + rand!(rng, model, varinfo) + + # Do something with `varinfo`. + # ... 
+end +``` + +There are typically a few scenarios where we encounter changing representation sizes of a random variable `x`: + + 1. We're working with a transformed version of `x` which is represented in a lower-dimensional space, e.g. transforming an `x ~ LKJ(2, 1)` to unconstrained `y = f(x)` takes us from a 2-by-2 `Matrix{Float64}` to a length-1 `Vector{Float64}`. + 2. `x` has a random size, e.g. in a mixture model with a prior on the number of components. Here the size of `x` can vary widely between realizations of the `Model`. + +In scenario (1), we're usually *shrinking* the representation of `x`, and so we end up not making any allocations for the underlying `Vector{T}` but instead just marking the redundant part as "inactive". + +In scenario (2), we end up increasing the allocated memory for the randomly sized `x`, eventually leading to a vector that is large enough to hold realizations without needing to reallocate. But this can still lead to unnecessary memory usage, which might be undesirable. Hence one has to make a decision regarding the trade-off between memory usage and performance for the use-case at hand. + +To help with this, we have the following functions: + +```@docs +DynamicPPL.has_inactive +DynamicPPL.num_inactive +DynamicPPL.num_allocated +DynamicPPL.is_contiguous +DynamicPPL.contiguify! 
+``` + +For example, one might encounter the following scenario: + +```@example varinfo-design +vnv = DynamicPPL.VarNamedVector(@varname(x) => [true]) +println("Before insertion: number of allocated entries $(DynamicPPL.num_allocated(vnv))") + +for i in 1:5 + x = fill(true, rand(1:100)) + DynamicPPL.update!(vnv, x, @varname(x)) + println( + "After insertion #$(i) of length $(length(x)): number of allocated entries $(DynamicPPL.num_allocated(vnv))", + ) +end +``` + +To reduce overall memory usage, we can insert a call to [`DynamicPPL.contiguify!`](@ref) after each insertion, triggered whenever the allocation grows too large: + +```@example varinfo-design +vnv = DynamicPPL.VarNamedVector(@varname(x) => [true]) +println("Before insertion: number of allocated entries $(DynamicPPL.num_allocated(vnv))") + +for i in 1:5 + x = fill(true, rand(1:100)) + DynamicPPL.update!(vnv, x, @varname(x)) + if DynamicPPL.num_allocated(vnv) > 10 + DynamicPPL.contiguify!(vnv) + end + println( + "After insertion #$(i) of length $(length(x)): number of allocated entries $(DynamicPPL.num_allocated(vnv))", + ) +end +``` + +This does incur a runtime cost as it requires re-allocation of the `ranges` in addition to a `resize!` of the underlying `Vector{T}`. However, this also ensures that the underlying `Vector{T}` is contiguous, which is important for performance. Hence, if we're about to do a lot of work with the `VarNamedVector` without insertions, etc., it can be worth it to do a sweep to ensure that the underlying `Vector{T}` is contiguous. + +!!! note + + Higher-dimensional arrays, e.g. `Matrix`, are handled by simply vectorizing them before storing them in the `Vector{T}`, and composing the `VarName`'s transformation with a `DynamicPPL.ReshapeTransform`. 
+ +Continuing from the example from the previous section, we can use a `VarInfo` with a `VarNamedVector` as the `metadata` field: + +```@example varinfo-design +# Type-unstable +varinfo_untyped_vnv = DynamicPPL.VectorVarInfo(varinfo_untyped) +varinfo_untyped_vnv[@varname(x)], varinfo_untyped_vnv[@varname(y)] +``` + +```@example varinfo-design +# Type-stable +varinfo_typed_vnv = DynamicPPL.VectorVarInfo(varinfo_typed) +varinfo_typed_vnv[@varname(x)], varinfo_typed_vnv[@varname(y)] +``` + +If we now try to `delete!` `@varname(x)` + +```@example varinfo-design +haskey(varinfo_untyped_vnv, @varname(x)) +``` + +```@example varinfo-design +DynamicPPL.has_inactive(varinfo_untyped_vnv.metadata) +``` + +```@example varinfo-design +# `delete!` +DynamicPPL.delete!(varinfo_untyped_vnv.metadata, @varname(x)) +DynamicPPL.has_inactive(varinfo_untyped_vnv.metadata) +``` + +```@example varinfo-design +haskey(varinfo_untyped_vnv, @varname(x)) +``` + +Or insert a differently-sized value for `@varname(x)` + +```@example varinfo-design +DynamicPPL.insert!(varinfo_untyped_vnv.metadata, fill(true, 1), @varname(x)) +varinfo_untyped_vnv[@varname(x)] +``` + +```@example varinfo-design +DynamicPPL.num_allocated(varinfo_untyped_vnv.metadata, @varname(x)) +``` + +```@example varinfo-design +DynamicPPL.update!(varinfo_untyped_vnv.metadata, fill(true, 4), @varname(x)) +varinfo_untyped_vnv[@varname(x)] +``` + +```@example varinfo-design +DynamicPPL.num_allocated(varinfo_untyped_vnv.metadata, @varname(x)) +``` + +### Performance summary + +In the end, we have the following "rough" performance characteristics for `VarNamedVector`: + +| Method | Is blazingly fast? 
| +|:----------------------------------------:|:--------------------------------------------------------------------------------------------:| +| `getindex` | ${\color{green} \checkmark}$ | +| `setindex!` on a new `VarName` | ${\color{green} \checkmark}$ | +| `delete!` | ${\color{red} \times}$ | +| `update!` on existing `VarName` | ${\color{green} \checkmark}$ if smaller or same size / ${\color{red} \times}$ if larger size | +| `values_as(::VarNamedVector, Vector{T})` | ${\color{green} \checkmark}$ if contiguous / ${\color{orange} \div}$ otherwise | + +## Other methods + +```@docs +DynamicPPL.replace_raw_storage(::DynamicPPL.VarNamedVector, vals::AbstractVector) +``` + +```@docs; canonical=false +DynamicPPL.values_as(::DynamicPPL.VarNamedVector) +``` diff --git a/ext/DynamicPPLChainRulesCoreExt.jl b/ext/DynamicPPLChainRulesCoreExt.jl index 1c6e188fb..1559467f8 100644 --- a/ext/DynamicPPLChainRulesCoreExt.jl +++ b/ext/DynamicPPLChainRulesCoreExt.jl @@ -24,4 +24,6 @@ ChainRulesCore.@non_differentiable DynamicPPL.updategid!( # No need + causes issues for some AD backends, e.g. Zygote. ChainRulesCore.@non_differentiable DynamicPPL.infer_nested_eltype(x) +ChainRulesCore.@non_differentiable DynamicPPL.recontiguify_ranges!(ranges) + end # module diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 7c7fb216d..c91fb1fe0 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -42,6 +42,65 @@ function DynamicPPL.varnames(c::MCMCChains.Chains) return keys(c.info.varname_to_symbol) end +""" + generated_quantities(model::Model, chain::MCMCChains.Chains) + +Execute `model` for each of the samples in `chain` and return an array of the values +returned by the `model` for each sample. + +# Examples +## General +Often you might have additional quantities computed inside the model that you want to +inspect, e.g. 
+```julia +@model function demo(x) + # sample and observe + θ ~ Prior() + x ~ Likelihood() + return interesting_quantity(θ, x) +end +m = demo(data) +chain = sample(m, alg, n) +# To inspect the `interesting_quantity(θ, x)` where `θ` is replaced by samples +# from the posterior/`chain`: +generated_quantities(m, chain) # <= results in a `Vector` of returned values + # from `interesting_quantity(θ, x)` +``` +## Concrete (and simple) +```julia +julia> using DynamicPPL, Turing + +julia> @model function demo(xs) + s ~ InverseGamma(2, 3) + m_shifted ~ Normal(10, √s) + m = m_shifted - 10 + + for i in eachindex(xs) + xs[i] ~ Normal(m, √s) + end + + return (m, ) + end +demo (generic function with 1 method) + +julia> model = demo(randn(10)); + +julia> chain = sample(model, MH(), 10); + +julia> generated_quantities(model, chain) +10×1 Array{Tuple{Float64},2}: + (2.1964758025119338,) + (2.1964758025119338,) + (0.09270081916291417,) + (0.09270081916291417,) + (0.09270081916291417,) + (0.09270081916291417,) + (0.09270081916291417,) + (0.043088571494005024,) + (-0.16489786710222099,) + (-0.16489786710222099,) +``` +""" function DynamicPPL.generated_quantities( model::DynamicPPL.Model, chain_full::MCMCChains.Chains ) @@ -49,14 +108,86 @@ function DynamicPPL.generated_quantities( varinfo = DynamicPPL.VarInfo(model) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) return map(iters) do (sample_idx, chain_idx) - # Update the varinfo with the current sample and make variables not present in `chain` - # to be sampled. 
- DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx) + if DynamicPPL.supports_varname_indexing(chain) + varname_pairs = _varname_pairs_with_varname_indexing( + chain, varinfo, sample_idx, chain_idx + ) + else + varname_pairs = _varname_pairs_without_varname_indexing( + chain, varinfo, sample_idx, chain_idx + ) + end + fixed_model = DynamicPPL.fix(model, Dict(varname_pairs)) + return fixed_model() + end +end + +""" + _varname_pairs_with_varname_indexing( + chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx + ) - # TODO: Some of the variables can be a view into the `varinfo`, so we need to - # `deepcopy` the `varinfo` before passing it to `model`. - model(deepcopy(varinfo)) +Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values +from the chain. + +This implementation assumes `chain` can be indexed using variable names, and is the +preferred implementation. +""" +function _varname_pairs_with_varname_indexing( + chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx +) + vns = DynamicPPL.varnames(chain) + vn_parents = Iterators.map(vns) do vn + # The call nested_setindex_maybe! is used to handle cases where vn is not + # the variable name used in the model, but rather subsumed by one. Except + # for the subsumption part, this could be + # vn => getindex_varname(chain, sample_idx, vn, chain_idx) + # TODO(mhauru) This call to nested_setindex_maybe! is unintuitive. + DynamicPPL.nested_setindex_maybe!( + varinfo, DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx), vn + ) end + varname_pairs = Iterators.map(Iterators.filter(!isnothing, vn_parents)) do vn_parent + vn_parent => varinfo[vn_parent] + end + return varname_pairs +end + +""" +Check which keys in `key_strings` are subsumed by `vn_string` and return their values. + +The subsumption check is done with `DynamicPPL.subsumes_string`, which is quite weak, and +won't catch all cases. We should get rid of this if we can. 
+""" +# TODO(mhauru) See docstring above. +function _vcat_subsumed_values(vn_string, values, key_strings) + indices = findall(Base.Fix1(DynamicPPL.subsumes_string, vn_string), key_strings) + return !isempty(indices) ? reduce(vcat, values[indices]) : nothing +end + +""" + _varname_pairs_without_varname_indexing( + chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx + ) + +Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values +from the chain. + +This implementation does not assume that `chain` can be indexed using variable names. It is +thus not guaranteed to work in cases where the variable names have complex subsumption +patterns, such as if the model has a variable `x` but the chain stores `x.a[1]`. +""" +function _varname_pairs_without_varname_indexing( + chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx +) + values = chain.value[sample_idx, :, chain_idx] + keys = Base.keys(chain) + keys_strings = map(string, keys) + varname_pairs = [ + vn => _vcat_subsumed_values(string(vn), values, keys_strings) for + vn in Base.keys(varinfo) + ] + return varname_pairs end end diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 777c770d4..a5d178125 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -177,6 +177,7 @@ include("sampler.jl") include("varname.jl") include("distribution_wrappers.jl") include("contexts.jl") +include("varnamedvector.jl") include("abstract_varinfo.jl") include("threadsafe.jl") include("varinfo.jl") diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 7ddd09b2e..3f513d71d 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -295,7 +295,7 @@ julia> # Just use an example model to construct the `VarInfo` because we're lazy julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; julia> # For the sake of brevity, let's just check the type. 
- md = values_as(vi); md.s isa DynamicPPL.Metadata + md = values_as(vi); md.s isa Union{DynamicPPL.Metadata, DynamicPPL.VarNamedVector} true julia> values_as(vi, NamedTuple) @@ -321,7 +321,7 @@ julia> # Just use an example model to construct the `VarInfo` because we're lazy julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; julia> # For the sake of brevity, let's just check the type. - values_as(vi) isa DynamicPPL.Metadata + values_as(vi) isa Union{DynamicPPL.Metadata, Vector} true julia> values_as(vi, NamedTuple) @@ -349,7 +349,7 @@ Determine the default `eltype` of the values returned by `vi[spl]`. This should generally not be called explicitly, as it's only used in [`matchingvalue`](@ref) to determine the default type to use in place of type-parameters passed to the model. - + This method is considered legacy, and is likely to be deprecated in the future. """ function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromPrior}) @@ -363,6 +363,13 @@ function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromP return eltype(T) end +""" + has_varnamedvector(varinfo::VarInfo) + +Returns `true` if `varinfo` uses `VarNamedVector` as metadata. +""" +has_varnamedvector(vi::AbstractVarInfo) = false + # TODO: Should relax constraints on `vns` to be `AbstractVector{<:Any}` and just try to convert # the `eltype` to `VarName`? This might be useful when someone does `[@varname(x[1]), @varname(m)]` which # might result in a `Vector{Any}`. @@ -554,7 +561,7 @@ end link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) link([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) -Transform the variables in `vi` to their linked space without mutating `vi`, using the transformation `t`. +Transform the variables in `vi` to their linked space without mutating `vi`, using the transformation `t`. If `t` is not provided, `default_transformation(model, vi)` will be used. 
@@ -573,7 +580,7 @@ end invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) -Transform the variables in `vi` to their constrained space, using the (inverse of) +Transform the variables in `vi` to their constrained space, using the (inverse of) transformation `t`, mutating `vi` if possible. If `t` is not provided, `default_transformation(model, vi)` will be used. diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 13231837f..f3c5171b0 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -240,9 +240,14 @@ function assume( if haskey(vi, vn) # Always overwrite the parameters with new ones for `SampleFromUniform`. if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") - unset_flag!(vi, vn, "del") + # TODO(mhauru) Is it important to unset the flag here? The `true` allows us + # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure + # if that's okay. + unset_flag!(vi, vn, "del", true) r = init(rng, dist, sampler) f = to_maybe_linked_internal_transform(vi, vn, dist) + # TODO(mhauru) This should probably be call a function called setindex_internal! + # Also, if we use !! we shouldn't ignore the return value. BangBang.setindex!!(vi, f(r), vn) setorder!(vi, vn, get_num_produce(vi)) else @@ -516,7 +521,10 @@ function get_and_set_val!( if haskey(vi, vns[1]) # Always overwrite the parameters with new ones for `SampleFromUniform`. if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del") - unset_flag!(vi, vns[1], "del") + # TODO(mhauru) Is it important to unset the flag here? The `true` allows us + # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure if + # that's okay. 
+ unset_flag!(vi, vns[1], "del", true) r = init(rng, dist, spl, n) for i in 1:n vn = vns[i] @@ -554,7 +562,10 @@ function get_and_set_val!( if haskey(vi, vns[1]) # Always overwrite the parameters with new ones for `SampleFromUniform`. if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del") - unset_flag!(vi, vns[1], "del") + # TODO(mhauru) Is it important to unset the flag here? The `true` allows us + # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure if + # that's okay. + unset_flag!(vi, vns[1], "del", true) f = (vn, dist) -> init(rng, dist, spl) r = f.(vns, dists) for i in eachindex(vns) diff --git a/src/model.jl b/src/model.jl index 082ec3871..2a1a6db88 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1203,74 +1203,6 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC end end -""" - generated_quantities(model::Model, chain::AbstractChains) - -Execute `model` for each of the samples in `chain` and return an array of the values -returned by the `model` for each sample. - -# Examples -## General -Often you might have additional quantities computed inside the model that you want to -inspect, e.g. 
-```julia -@model function demo(x) - # sample and observe - θ ~ Prior() - x ~ Likelihood() - return interesting_quantity(θ, x) -end -m = demo(data) -chain = sample(m, alg, n) -# To inspect the `interesting_quantity(θ, x)` where `θ` is replaced by samples -# from the posterior/`chain`: -generated_quantities(m, chain) # <= results in a `Vector` of returned values - # from `interesting_quantity(θ, x)` -``` -## Concrete (and simple) -```julia -julia> using DynamicPPL, Turing - -julia> @model function demo(xs) - s ~ InverseGamma(2, 3) - m_shifted ~ Normal(10, √s) - m = m_shifted - 10 - - for i in eachindex(xs) - xs[i] ~ Normal(m, √s) - end - - return (m, ) - end -demo (generic function with 1 method) - -julia> model = demo(randn(10)); - -julia> chain = sample(model, MH(), 10); - -julia> generated_quantities(model, chain) -10×1 Array{Tuple{Float64},2}: - (2.1964758025119338,) - (2.1964758025119338,) - (0.09270081916291417,) - (0.09270081916291417,) - (0.09270081916291417,) - (0.09270081916291417,) - (0.09270081916291417,) - (0.043088571494005024,) - (-0.16489786710222099,) - (-0.16489786710222099,) -``` -""" -function generated_quantities(model::Model, chain::AbstractChains) - varinfo = VarInfo(model) - iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) - return map(iters) do (sample_idx, chain_idx) - setval_and_resample!(varinfo, chain, sample_idx, chain_idx) - model(varinfo) - end -end - """ generated_quantities(model::Model, parameters::NamedTuple) generated_quantities(model::Model, values, keys) @@ -1297,7 +1229,7 @@ demo (generic function with 2 methods) julia> model = demo(randn(10)); -julia> parameters = (; s = 1.0, m_shifted=10); +julia> parameters = (; s = 1.0, m_shifted=10.0); julia> generated_quantities(model, parameters) (0.0,) @@ -1307,13 +1239,10 @@ julia> generated_quantities(model, values(parameters), keys(parameters)) ``` """ function generated_quantities(model::Model, parameters::NamedTuple) - varinfo = VarInfo(model) - 
setval_and_resample!(varinfo, values(parameters), keys(parameters)) - return model(varinfo) + fixed_model = fix(model, parameters) + return fixed_model() end function generated_quantities(model::Model, values, keys) - varinfo = VarInfo(model) - setval_and_resample!(varinfo, values, keys) - return model(varinfo) + return generated_quantities(model, NamedTuple{keys}(values)) end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index d8afb9cec..88f892a72 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -322,15 +322,17 @@ Base.getindex(vi::SimpleVarInfo, vn::VarName) = getindex_internal(vi, vn) function Base.getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) return map(Base.Fix1(getindex, vi), vns) end -# HACK: Needed to disambiguiate. +# HACK: Needed to disambiguate. Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(Base.Fix1(getindex, vi), vns) Base.getindex(svi::SimpleVarInfo, ::Colon) = values_as(svi, Vector) getindex_internal(vi::SimpleVarInfo, vn::VarName) = get(vi.values, vn) # `AbstractDict` -function getindex_internal(vi::SimpleVarInfo{<:AbstractDict}, vn::VarName) - return nested_getindex(vi.values, vn) +function getindex_internal( + vi::SimpleVarInfo{<:Union{AbstractDict,VarNamedVector}}, vn::VarName +) + return getvalue(vi.values, vn) end Base.haskey(vi::SimpleVarInfo, vn::VarName) = hasvalue(vi.values, vn) @@ -399,14 +401,28 @@ end function BangBang.push!!( vi::SimpleVarInfo{<:AbstractDict}, vn::VarName, - r, + value, dist::Distribution, gidset::Set{Selector}, ) - vi.values[vn] = r + vi.values[vn] = value return vi end +function BangBang.push!!( + vi::SimpleVarInfo{<:VarNamedVector}, + vn::VarName, + value, + dist::Distribution, + gidset::Set{Selector}, +) + # The semantics of push!! for SimpleVarInfo and VarNamedVector are different. For + # SimpleVarInfo, push!! allows the key to exist already, for VarNamedVector it does not. + # Hence we need to call update!! here, which has the same semantics as push!! 
does for + # SimpleVarInfo. + return Accessors.@set vi.values = setindex!!(vi.values, value, vn) +end + const SimpleOrThreadSafeSimple{T,V,C} = Union{ SimpleVarInfo{T,V,C},ThreadSafeVarInfo{<:SimpleVarInfo{T,V,C}} } @@ -456,6 +472,8 @@ function _subset(x::NamedTuple, vns) return NamedTuple{Tuple(syms)}(Tuple(map(Base.Fix1(getindex, x), syms))) end +_subset(x::VarNamedVector, vns) = subset(x, vns) + # `merge` function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo) values = merge(varinfo_left.values, varinfo_right.values) @@ -563,6 +581,9 @@ end function values_as(vi::SimpleVarInfo{<:AbstractDict}, ::Type{NamedTuple}) return NamedTuple((Symbol(k), v) for (k, v) in vi.values) end +function values_as(vi::SimpleVarInfo, ::Type{T}) where {T} + return values_as(vi.values, T) +end """ logjoint(model::Model, θ) @@ -708,3 +729,5 @@ end function ThreadSafeVarInfo(vi::SimpleVarInfo{<:Any,<:Ref}) return ThreadSafeVarInfo(vi, [Ref(zero(getlogp(vi))) for _ in 1:Threads.nthreads()]) end + +has_varnamedvector(vi::SimpleVarInfo) = vi.values isa VarNamedVector diff --git a/src/test_utils.jl b/src/test_utils.jl index 8489f2684..6199138aa 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -37,20 +37,35 @@ function setup_varinfos( model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false ) # VarInfo - vi_untyped = VarInfo() - model(vi_untyped) - vi_typed = DynamicPPL.TypedVarInfo(vi_untyped) + vi_untyped_metadata = VarInfo(DynamicPPL.Metadata()) + vi_untyped_vnv = VarInfo(DynamicPPL.VarNamedVector()) + model(vi_untyped_metadata) + model(vi_untyped_vnv) + vi_typed_metadata = DynamicPPL.TypedVarInfo(vi_untyped_metadata) + vi_typed_vnv = DynamicPPL.TypedVarInfo(vi_untyped_vnv) + # SimpleVarInfo svi_typed = SimpleVarInfo(example_values) svi_untyped = SimpleVarInfo(OrderedDict()) + svi_vnv = SimpleVarInfo(DynamicPPL.VarNamedVector()) # SimpleVarInfo{<:Any,<:Ref} svi_typed_ref = SimpleVarInfo(example_values, Ref(getlogp(svi_typed))) 
svi_untyped_ref = SimpleVarInfo(OrderedDict(), Ref(getlogp(svi_untyped))) + svi_vnv_ref = SimpleVarInfo(DynamicPPL.VarNamedVector(), Ref(getlogp(svi_vnv))) - lp = getlogp(vi_typed) + lp = getlogp(vi_typed_metadata) varinfos = map(( - vi_untyped, vi_typed, svi_typed, svi_untyped, svi_typed_ref, svi_untyped_ref + vi_untyped_metadata, + vi_untyped_vnv, + vi_typed_metadata, + vi_typed_vnv, + svi_typed, + svi_untyped, + svi_vnv, + svi_typed_ref, + svi_untyped_ref, + svi_vnv_ref, )) do vi # Set them all to the same values. DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 4fbf0d124..ec890a674 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -55,6 +55,8 @@ function setlogp!!(vi::ThreadSafeVarInfoWithRef, logp) return ThreadSafeVarInfo(setlogp!!(vi.varinfo, logp), vi.logps) end +has_varnamedvector(vi::DynamicPPL.ThreadSafeVarInfo) = has_varnamedvector(vi.varinfo) + function BangBang.push!!( vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} ) @@ -188,8 +190,10 @@ end values_as(vi::ThreadSafeVarInfo) = values_as(vi.varinfo) values_as(vi::ThreadSafeVarInfo, ::Type{T}) where {T} = values_as(vi.varinfo, T) -function unset_flag!(vi::ThreadSafeVarInfo, vn::VarName, flag::String) - return unset_flag!(vi.varinfo, vn, flag) +function unset_flag!( + vi::ThreadSafeVarInfo, vn::VarName, flag::String, ignoreable::Bool=false +) + return unset_flag!(vi.varinfo, vn, flag, ignoreable) end function is_flagged(vi::ThreadSafeVarInfo, vn::VarName, flag::String) return is_flagged(vi.varinfo, vn, flag) diff --git a/src/utils.jl b/src/utils.jl index f5057f4d6..bd5d365fc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -48,7 +48,7 @@ true i.e., regardless of whether you evaluate the log prior, the log likelihood or the log joint density. If you would like to avoid this behaviour you should check the evaluation context. It can be accessed with the internal variable `__context__`. 
- For instance, in the following example the log density is not accumulated when only the log prior is computed: + For instance, in the following example the log density is not accumulated when only the log prior is computed: ```jldoctest; setup = :(using Distributions) julia> myloglikelihood(x, μ) = loglikelihood(Normal(μ, 1), x); @@ -225,21 +225,82 @@ invlink_transform(dist) = inverse(link_transform(dist)) # Helper functions for vectorize/reconstruct values # ##################################################### -# Useful transformation going from the flattened representation. -struct FromVec{Size} <: Bijectors.Bijector - size::Size +""" + UnwrapSingletonTransform(input_size::InSize) + +A transformation that unwraps a singleton array, returning a scalar. + +The `input_size` field is the expected size of the input. In practice this only determines +the number of indices, since all dimensions must be 1 for a singleton. `input_size` is used +to check the validity of the input, but also to determine the correct inverse operation. + +By default `input_size` is `(1,)`, in which case `tovec` is the inverse. +""" +struct UnwrapSingletonTransform{InSize} <: Bijectors.Bijector + input_size::InSize end -FromVec(x::Union{Real,AbstractArray}) = FromVec(size(x)) +UnwrapSingletonTransform() = UnwrapSingletonTransform((1,)) + +function (f::UnwrapSingletonTransform)(x) + if size(x) != f.input_size + throw(DimensionMismatch("Expected input of size $(f.input_size), got $(size(x))")) + end + return only(x) +end + +Bijectors.with_logabsdet_jacobian(f::UnwrapSingletonTransform, x) = (f(x), 0) +function Bijectors.with_logabsdet_jacobian( + inv_f::Bijectors.Inverse{<:UnwrapSingletonTransform}, x +) + f = inv_f.orig + return (reshape([x], f.input_size), 0) +end + +""" + ReshapeTransform(input_size::InSize, output_size::OutSize) + +A `Bijector` that transforms arrays of size `input_size` to arrays of size `output_size`. 
+ +`input_size` is not needed for the implementation of the transformation. It is only used to +check that the input is of the expected size, and to determine the correct inverse +operation. + +By default `input_size` is the vectorized version of `output_size`. In this case this +transformation is the inverse of `tovec` called on an array. +""" +struct ReshapeTransform{InSize,OutSize} <: Bijectors.Bijector + input_size::InSize + output_size::OutSize +end + +function ReshapeTransform(output_size::Tuple) + input_size = (prod(output_size),) + return ReshapeTransform(input_size, output_size) +end + +ReshapeTransform(x::AbstractArray) = ReshapeTransform(size(x)) # TODO: Should we materialize the `reshape`? -(f::FromVec)(x) = reshape(x, f.size) -(f::FromVec{Tuple{}})(x) = only(x) -# TODO: Specialize for `Tuple{<:Any}` since this correspond to a `Vector`. +function (f::ReshapeTransform)(x) + if size(x) != f.input_size + throw(DimensionMismatch("Expected input of size $(f.input_size), got $(size(x))")) + end + # The call to `tovec` is only needed in case `x` is a scalar. + return reshape(tovec(x), f.output_size) +end + +function (inv_f::Bijectors.Inverse{<:ReshapeTransform})(x) + f = inv_f.orig + inverse = ReshapeTransform(f.output_size, f.input_size) + return inverse(x) +end -Bijectors.with_logabsdet_jacobian(f::FromVec, x) = (f(x), 0) -# We want to use the inverse of `FromVec` so it preserves the size information. 
-Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:FromVec}, x) = (tovec(x), 0) +Bijectors.with_logabsdet_jacobian(f::ReshapeTransform, x) = (f(x), 0) + +function Bijectors.with_logabsdet_jacobian(inv_f::Bijectors.Inverse{<:ReshapeTransform}, x) + return (inv_f(x), 0) +end struct ToChol <: Bijectors.Bijector uplo::Char @@ -247,22 +308,30 @@ end Bijectors.with_logabsdet_jacobian(f::ToChol, x) = (Cholesky(Matrix(x), f.uplo, 0), 0) Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:ToChol}, y::Cholesky) = (y.UL, 0) +function Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:ToChol}, y) + return error( + "Inverse{ToChol} is only defined for Cholesky factorizations. " * + "Got a $(typeof(y)) instead.", + ) +end """ from_vec_transform(x) Return the transformation from the vector representation of `x` to original representation. """ -from_vec_transform(x::Union{Real,AbstractArray}) = from_vec_transform_for_size(size(x)) -from_vec_transform(C::Cholesky) = ToChol(C.uplo) ∘ FromVec(size(C.UL)) +from_vec_transform(x::AbstractArray) = from_vec_transform_for_size(size(x)) +from_vec_transform(C::Cholesky) = ToChol(C.uplo) ∘ ReshapeTransform(size(C.UL)) +from_vec_transform(::Real) = UnwrapSingletonTransform() """ from_vec_transform_for_size(sz::Tuple) -Return the transformation from the vector representation of a realization of size `sz` to original representation. +Return the transformation from the vector representation of a realization of size `sz` to +original representation. """ -from_vec_transform_for_size(sz::Tuple) = FromVec(sz) -from_vec_transform_for_size(::Tuple{()}) = FromVec(()) +from_vec_transform_for_size(sz::Tuple) = ReshapeTransform(sz) +# TODO(mhauru) Is the below used? If not, this function can be removed. from_vec_transform_for_size(::Tuple{<:Any}) = identity """ @@ -272,7 +341,8 @@ Return the transformation from the vector representation of a realization from distribution `dist` to the original representation compatible with `dist`. 
""" from_vec_transform(dist::Distribution) = from_vec_transform_for_size(size(dist)) -from_vec_transform(dist::LKJCholesky) = ToChol(dist.uplo) ∘ FromVec(size(dist)) +from_vec_transform(::UnivariateDistribution) = UnwrapSingletonTransform() +from_vec_transform(dist::LKJCholesky) = ToChol(dist.uplo) ∘ ReshapeTransform(size(dist)) """ from_vec_transform(f, size::Tuple) @@ -300,6 +370,19 @@ function from_linked_vec_transform(dist::Distribution) return f_invlink ∘ f_vec end +# UnivariateDistributions need to be handled as a special case, because size(dist) is (), +# which makes the usual machinery think we are dealing with a 0-dim array, whereas in +# actuality we are dealing with a scalar. +# TODO(mhauru) Hopefully all this can go once the old Gibbs sampler is removed and +# VarNamedVector takes over from Metadata. +function from_linked_vec_transform(dist::UnivariateDistribution) + f_invlink = invlink_transform(dist) + f_vec = from_vec_transform(inverse(f_invlink), size(dist)) + f_combined = f_invlink ∘ f_vec + sz = Bijectors.output_size(f_combined, size(dist)) + return UnwrapSingletonTransform(sz) ∘ f_combined +end + # Specializations that circumvent the `from_vec_transform` machinery. function from_linked_vec_transform(dist::LKJCholesky) return inverse(Bijectors.VecCholeskyBijector(dist.uplo)) @@ -854,6 +937,7 @@ end Return type corresponding to `float(typeof(x))` if possible; otherwise return `Real`. """ float_type_with_fallback(::Type) = Real +float_type_with_fallback(::Type{Union{}}) = Real float_type_with_fallback(::Type{T}) where {T<:Real} = float(T) """ diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index 52ba6eb61..c5003d53a 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -177,7 +177,7 @@ julia> # Approach 1: Convert back to constrained space using `invlink` and extra julia> # (×) Fails! 
Because `VarInfo` _saves_ the original distributions # used in the very first model evaluation, hence the support of `y` # is not updated even though `x` has changed. - lb ≤ varinfo_invlinked[@varname(y)] ≤ ub + lb ≤ first(varinfo_invlinked[@varname(y)]) ≤ ub false julia> # Approach 2: Extract realizations using `values_as_in_model`. diff --git a/src/varinfo.jl b/src/varinfo.jl index 2670397d9..8727796bc 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -101,6 +101,7 @@ struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo logp::Base.RefValue{Tlogp} num_produce::Base.RefValue{Int} end +const VectorVarInfo = VarInfo{<:VarNamedVector} const UntypedVarInfo = VarInfo{<:Metadata} const TypedVarInfo = VarInfo{<:NamedTuple} const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ @@ -119,6 +120,49 @@ function VarInfo(old_vi::VarInfo, spl, x::AbstractVector) ) end +# No-op if we're already working with a `VarNamedVector`. +metadata_to_varnamedvector(vnv::VarNamedVector) = vnv +function metadata_to_varnamedvector(md::Metadata) + idcs = copy(md.idcs) + vns = copy(md.vns) + ranges = copy(md.ranges) + vals = copy(md.vals) + is_unconstrained = map(Base.Fix1(istrans, md), md.vns) + transforms = map(md.dists, is_unconstrained) do dist, trans + if trans + return from_linked_vec_transform(dist) + else + return from_vec_transform(dist) + end + end + + return VarNamedVector( + OrderedDict{eltype(keys(idcs)),Int}(idcs), + vns, + ranges, + vals, + transforms, + is_unconstrained, + ) +end + +function VectorVarInfo(vi::UntypedVarInfo) + md = metadata_to_varnamedvector(vi.metadata) + lp = getlogp(vi) + return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) +end + +function VectorVarInfo(vi::TypedVarInfo) + md = map(metadata_to_varnamedvector, vi.metadata) + lp = getlogp(vi) + return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) +end + +function has_varnamedvector(vi::VarInfo) + return vi.metadata isa VarNamedVector || + (vi isa TypedVarInfo && 
any(Base.Fix2(isa, VarNamedVector), values(vi.metadata))) +end + """ untyped_varinfo([rng, ]model[, sampler, context]) @@ -129,11 +173,14 @@ function untyped_varinfo( model::Model, sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), + metadata::Union{Metadata,VarNamedVector}=Metadata(), ) - varinfo = VarInfo() + varinfo = VarInfo(metadata) return last(evaluate!!(model, varinfo, SamplingContext(rng, sampler, context))) end -function untyped_varinfo(model::Model, args::Union{AbstractSampler,AbstractContext}...) +function untyped_varinfo( + model::Model, args::Union{AbstractSampler,AbstractContext,Metadata,VarNamedVector}... +) return untyped_varinfo(Random.default_rng(), model, args...) end @@ -149,15 +196,51 @@ function VarInfo( model::Model, sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), + metadata::Union{Metadata,VarNamedVector}=Metadata(), ) - return typed_varinfo(rng, model, sampler, context) + return typed_varinfo(rng, model, sampler, context, metadata) end VarInfo(model::Model, args...) = VarInfo(Random.default_rng(), model, args...) unflatten(vi::VarInfo, x::AbstractVector) = unflatten(vi, SampleFromPrior(), x) # TODO: deprecate. -unflatten(vi::VarInfo, spl::AbstractSampler, x::AbstractVector) = VarInfo(vi, spl, x) +function unflatten(vi::VarInfo, spl::AbstractSampler, x::AbstractVector) + md = unflatten(vi.metadata, spl, x) + return VarInfo(md, Base.RefValue{eltype(x)}(getlogp(vi)), Ref(get_num_produce(vi))) +end + +# The Val(getspace(spl)) is used to dispatch into the below generated function. 
+function unflatten(metadata::NamedTuple, spl::AbstractSampler, x::AbstractVector) + return unflatten(metadata, Val(getspace(spl)), x) +end + +@generated function unflatten( + metadata::NamedTuple{names}, ::Val{space}, x +) where {names,space} + exprs = [] + offset = :(0) + for f in names + mdf = :(metadata.$f) + if inspace(f, space) || length(space) == 0 + len = :(sum(length, $mdf.ranges)) + push!(exprs, :($f = unflatten($mdf, x[($offset + 1):($offset + $len)]))) + offset = :($offset + $len) + else + push!(exprs, :($f = $mdf)) + end + end + length(exprs) == 0 && return :(NamedTuple()) + return :($(exprs...),) +end + +# For Metadata unflatten and replace_values are the same. For VarNamedVector they are not. +function unflatten(md::Metadata, x::AbstractVector) + return replace_values(md, x) +end +function unflatten(md::Metadata, spl::AbstractSampler, x::AbstractVector) + return replace_values(md, spl, x) +end # without AbstractSampler function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) @@ -256,13 +339,22 @@ end function subset(varinfo::UntypedVarInfo, vns::AbstractVector{<:VarName}) metadata = subset(varinfo.metadata, vns) - return VarInfo(metadata, varinfo.logp, varinfo.num_produce) + return VarInfo(metadata, deepcopy(varinfo.logp), deepcopy(varinfo.num_produce)) +end + +function subset(varinfo::VectorVarInfo, vns::AbstractVector{<:VarName}) + metadata = subset(varinfo.metadata, vns) + return VarInfo(metadata, deepcopy(varinfo.logp), deepcopy(varinfo.num_produce)) end function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName{sym}}) where {sym} # If all the variables are using the same symbol, then we can just extract that field from the metadata. 
metadata = subset(getfield(varinfo.metadata, sym), vns) - return VarInfo(NamedTuple{(sym,)}(tuple(metadata)), varinfo.logp, varinfo.num_produce) + return VarInfo( + NamedTuple{(sym,)}(tuple(metadata)), + deepcopy(varinfo.logp), + deepcopy(varinfo.num_produce), + ) end function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName}) @@ -271,7 +363,9 @@ function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName}) subset(getfield(varinfo.metadata, sym), filter(==(sym) ∘ getsym, vns)) end - return VarInfo(NamedTuple{syms}(metadatas), varinfo.logp, varinfo.num_produce) + return VarInfo( + NamedTuple{syms}(metadatas), deepcopy(varinfo.logp), deepcopy(varinfo.num_produce) + ) end function subset(metadata::Metadata, vns_given::AbstractVector{<:VarName}) @@ -338,6 +432,10 @@ function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo) ) end +function merge_metadata(vnv_left::VarNamedVector, vnv_right::VarNamedVector) + return merge(vnv_left, vnv_right) +end + @generated function merge_metadata( metadata_left::NamedTuple{names_left}, metadata_right::NamedTuple{names_right} ) where {names_left,names_right} @@ -528,6 +626,10 @@ Return the distribution from which `vn` was sampled in `vi`. """ getdist(vi::VarInfo, vn::VarName) = getdist(getmetadata(vi, vn), vn) getdist(md::Metadata, vn::VarName) = md.dists[getidx(md, vn)] +# TODO(mhauru) Remove this once the old Gibbs sampler stuff is gone. +function getdist(::VarNamedVector, ::VarName) + throw(ErrorException("getdist does not exist for VarNamedVector")) +end getindex_internal(vi::VarInfo, vn::VarName) = getindex_internal(getmetadata(vi, vn), vn) # TODO(torfjelde): Use `view` instead of `getindex`. 
Requires addressing type-stability issues though, @@ -571,6 +673,7 @@ function getall(md::Metadata) Base.Fix1(getindex_internal, md), vcat, md.vns; init=similar(md.vals, 0) ) end +getall(vnv::VarNamedVector) = getindex_internal(vnv, Colon()) """ setall!(vi::VarInfo, val) @@ -586,6 +689,12 @@ function _setall!(metadata::Metadata, val) metadata.vals[r] .= val[r] end end +function _setall!(vnv::VarNamedVector, val) + # TODO(mhauru) Do something more efficient here. + for i in 1:length_internal(vnv) + setindex_internal!(vnv, val[i], i) + end +end @generated function _setall!(metadata::NamedTuple{names}, val) where {names} expr = Expr(:block) start = :(1) @@ -698,7 +807,7 @@ end length(exprs) == 0 && return :(NamedTuple()) return :($(exprs...),) end -@inline function findinds(f_meta, s, ::Val{space}) where {space} +@inline function findinds(f_meta::Metadata, s, ::Val{space}) where {space} # Get all the idcs of the vns in `space` and that belong to the selector `s` return filter( (i) -> @@ -707,11 +816,27 @@ end 1:length(f_meta.gids), ) end -@inline function findinds(f_meta) +@inline function findinds(f_meta::Metadata) # Get all the idcs of the vns return filter((i) -> isempty(f_meta.gids[i]), 1:length(f_meta.gids)) end +function findinds(vnv::VarNamedVector, ::Selector, ::Val{space}) where {space} + # New Metadata objects are created with an empty list of gids, which is interpreted as + # all Selectors applying to all variables. We assume the same behavior for + # VarNamedVector, and thus ignore the Selector argument.
+ if space !== () + msg = "VarNamedVector does not support selecting variables based on samplers" + throw(ErrorException(msg)) + else + return findinds(vnv) + end +end + +function findinds(vnv::VarNamedVector) + return 1:length(vnv.varnames) +end + # Get all vns of variables belonging to spl _getvns(vi::VarInfo, spl::Sampler) = _getvns(vi, spl.selector, Val(getspace(spl))) function _getvns(vi::VarInfo, spl::Union{SampleFromPrior,SampleFromUniform}) @@ -727,7 +852,7 @@ end @generated function _getvns(metadata, idcs::NamedTuple{names}) where {names} exprs = [] for f in names - push!(exprs, :($f = metadata.$f.vns[idcs.$f])) + push!(exprs, :($f = Base.keys(metadata.$f)[idcs.$f])) end length(exprs) == 0 && return :(NamedTuple()) return :($(exprs...),) @@ -774,6 +899,8 @@ end return results end +# TODO(mhauru) These set_flag! methods return the VarInfo. They should probably be called +# set_flag!!. """ set_flag!(vi::VarInfo, vn::VarName, flag::String) @@ -787,6 +914,15 @@ function set_flag!(md::Metadata, vn::VarName, flag::String) return md.flags[flag][getidx(md, vn)] = true end +function set_flag!(vnv::VarNamedVector, ::VarName, flag::String) + if flag == "del" + # The "del" flag is effectively always set for a VarNamedVector, so this is a no-op. 
+ else + throw(ErrorException("Flag $flag not valid for VarNamedVector")) + end + return vnv +end + #### #### APIs for typed and untyped VarInfo #### @@ -795,6 +931,14 @@ end VarInfo(meta=Metadata()) = VarInfo(meta, Ref{Float64}(0.0), Ref(0)) +function TypedVarInfo(vi::VectorVarInfo) + new_metas = group_by_symbol(vi.metadata) + logp = getlogp(vi) + num_produce = get_num_produce(vi) + nt = NamedTuple(new_metas) + return VarInfo(nt, Ref(logp), Ref(num_produce)) +end + """ TypedVarInfo(vi::UntypedVarInfo) @@ -878,7 +1022,7 @@ end # `keys` Base.keys(md::Metadata) = md.vns -Base.keys(vi::VarInfo) = keys(vi.metadata) +Base.keys(vi::VarInfo) = Base.keys(vi.metadata) # HACK: Necessary to avoid returning `Any[]` which won't dispatch correctly # on other methods in the codebase which requires `Vector{<:VarName}`. @@ -905,8 +1049,14 @@ end Add `gid` to the set of sampler selectors associated with `vn` in `vi`. """ -function setgid!(vi::VarInfo, gid::Selector, vn::VarName) - return push!(getmetadata(vi, vn).gids[getidx(vi, vn)], gid) +setgid!(vi::VarInfo, gid::Selector, vn::VarName) = setgid!(getmetadata(vi, vn), gid, vn) + +function setgid!(m::Metadata, gid::Selector, vn::VarName) + return push!(m.gids[getidx(m, vn)], gid) +end + +function setgid!(vnv::VarNamedVector, gid::Selector, vn::VarName) + throw(ErrorException("Calling setgid! on a VarNamedVector isn't valid.")) end istrans(vi::VarInfo, vn::VarName) = istrans(getmetadata(vi, vn), vn) @@ -953,18 +1103,18 @@ and parameters sampled in `vi` to 0. """ reset_num_produce!(vi::VarInfo) = set_num_produce!(vi, 0) -isempty(vi::UntypedVarInfo) = isempty(vi.metadata.idcs) -isempty(vi::TypedVarInfo) = _isempty(vi.metadata) +# Need to introduce the _isempty to avoid type piracy of isempty(::NamedTuple). 
+isempty(vi::VarInfo) = _isempty(vi.metadata) +_isempty(metadata::Metadata) = isempty(metadata.idcs) +_isempty(vnv::VarNamedVector) = isempty(vnv) @generated function _isempty(metadata::NamedTuple{names}) where {names} - expr = Expr(:&&, :true) - for f in names - push!(expr.args, :(isempty(metadata.$f.idcs))) - end - return expr + return Expr(:&&, (:(_isempty(metadata.$f)) for f in names)...) end # X -> R for all variables associated with given sampler function link!!(t::DynamicTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) + # If we're working with a `VarNamedVector`, we always use immutable. + has_varnamedvector(vi) && return link(t, vi, spl, model) # Call `_link!` instead of `link!` to avoid deprecation warning. _link!(vi, spl) return vi @@ -1007,10 +1157,8 @@ function _link!(vi::UntypedVarInfo, spl::AbstractSampler) vns = _getvns(vi, spl) if ~istrans(vi, vns[1]) for vn in vns - dist = getdist(vi, vn) - _inner_transform!( - vi, vn, dist, internal_to_linked_internal_transform(vi, vn, dist) - ) + f = internal_to_linked_internal_transform(vi, vn) + _inner_transform!(vi, vn, f) settrans!!(vi, true, vn) end else @@ -1037,13 +1185,8 @@ end if ~istrans(vi, f_vns[1]) # Iterate over all `f_vns` and transform for vn in f_vns - dist = getdist(vi, vn) - _inner_transform!( - vi, - vn, - dist, - internal_to_linked_internal_transform(vi, vn, dist), - ) + f = internal_to_linked_internal_transform(vi, vn) + _inner_transform!(vi, vn, f) settrans!!(vi, true, vn) end else @@ -1060,6 +1203,8 @@ end function invlink!!( t::DynamicTransformation, vi::VarInfo, spl::AbstractSampler, model::Model ) + # If we're working with a `VarNamedVector`, we always use immutable. + has_varnamedvector(vi) && return invlink(t, vi, spl, model) # Call `_invlink!` instead of `invlink!` to avoid deprecation warning. 
_invlink!(vi, spl) return vi @@ -1111,10 +1256,8 @@ function _invlink!(vi::UntypedVarInfo, spl::AbstractSampler) vns = _getvns(vi, spl) if istrans(vi, vns[1]) for vn in vns - dist = getdist(vi, vn) - _inner_transform!( - vi, vn, dist, linked_internal_to_internal_transform(vi, vn, dist) - ) + f = linked_internal_to_internal_transform(vi, vn) + _inner_transform!(vi, vn, f) settrans!!(vi, false, vn) end else @@ -1141,13 +1284,8 @@ end if istrans(vi, f_vns[1]) # Iterate over all `f_vns` and transform for vn in f_vns - dist = getdist(vi, vn) - _inner_transform!( - vi, - vn, - dist, - linked_internal_to_internal_transform(vi, vn, dist), - ) + f = linked_internal_to_internal_transform(vi, vn) + _inner_transform!(vi, vn, f) settrans!!(vi, false, vn) end else @@ -1160,11 +1298,11 @@ end return expr end -function _inner_transform!(vi::VarInfo, vn::VarName, dist, f) - return _inner_transform!(getmetadata(vi, vn), vi, vn, dist, f) +function _inner_transform!(vi::VarInfo, vn::VarName, f) + return _inner_transform!(getmetadata(vi, vn), vi, vn, f) end -function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, dist, f) +function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f) # TODO: Use inplace versions to avoid allocations yvec, logjac = with_logabsdet_jacobian(f, getindex_internal(vi, vn)) # Determine the new range. 
@@ -1202,10 +1340,12 @@ function link( return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, spl, model) end -function _link(model::Model, varinfo::UntypedVarInfo, spl::AbstractSampler) +function _link( + model::Model, varinfo::Union{UntypedVarInfo,VectorVarInfo}, spl::AbstractSampler +) varinfo = deepcopy(varinfo) return VarInfo( - _link_metadata!(model, varinfo, varinfo.metadata, _getvns_link(varinfo, spl)), + _link_metadata!!(model, varinfo, varinfo.metadata, _getvns_link(varinfo, spl)), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) @@ -1229,7 +1369,7 @@ end vals = Expr(:tuple) for f in names if inspace(f, space) || length(space) == 0 - push!(vals.args, :(_link_metadata!(model, varinfo, metadata.$f, vns.$f))) + push!(vals.args, :(_link_metadata!!(model, varinfo, metadata.$f, vns.$f))) else push!(vals.args, :(metadata.$f)) end @@ -1237,7 +1377,7 @@ end return :(NamedTuple{$names}($vals)) end -function _link_metadata!(model::Model, varinfo::VarInfo, metadata::Metadata, target_vns) +function _link_metadata!!(model::Model, varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns # Construct the new transformed values, and keep track of their lengths. @@ -1257,7 +1397,7 @@ function _link_metadata!(model::Model, varinfo::VarInfo, metadata::Metadata, tar yvec = tovec(y) # Accumulate the log-abs-det jacobian correction. acclogp!!(varinfo, -logjac) - # Mark as no longer transformed. + # Mark as transformed. settrans!!(varinfo, true, vn) # Return the vectorized transformed value. return yvec @@ -1285,6 +1425,30 @@ function _link_metadata!(model::Model, varinfo::VarInfo, metadata::Metadata, tar ) end +function _link_metadata!!( + model::Model, varinfo::VarInfo, metadata::VarNamedVector, target_vns +) + vns = target_vns === nothing ? keys(metadata) : target_vns + dists = extract_priors(model, varinfo) + for vn in vns + # First transform from however the variable is stored in vnv to the model + # representation. 
+ transform_to_orig = gettransform(metadata, vn) + val_old = getindex_internal(metadata, vn) + val_orig, logjac1 = with_logabsdet_jacobian(transform_to_orig, val_old) + # Then transform from the model representation to the linked representation. + transform_from_linked = from_linked_vec_transform(dists[vn]) + transform_to_linked = inverse(transform_from_linked) + val_new, logjac2 = with_logabsdet_jacobian(transform_to_linked, val_orig) + # TODO(mhauru) We are calling a !! function but ignoring the return value. + # Fix this when attending to issue #653. + acclogp!!(varinfo, -logjac1 - logjac2) + metadata = setindex_internal!!(metadata, val_new, vn, transform_from_linked) + settrans!(metadata, true, vn) + end + return metadata +end + function invlink( ::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model ) @@ -1304,7 +1468,7 @@ end function _invlink(model::Model, varinfo::VarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) return VarInfo( - _invlink_metadata!(model, varinfo, varinfo.metadata, _getvns_link(varinfo, spl)), + _invlink_metadata!!(model, varinfo, varinfo.metadata, _getvns_link(varinfo, spl)), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) @@ -1328,7 +1492,7 @@ end vals = Expr(:tuple) for f in names if inspace(f, space) || length(space) == 0 - push!(vals.args, :(_invlink_metadata!(model, varinfo, metadata.$f, vns.$f))) + push!(vals.args, :(_invlink_metadata!!(model, varinfo, metadata.$f, vns.$f))) else push!(vals.args, :(metadata.$f)) end @@ -1336,13 +1500,13 @@ end return :(NamedTuple{$names}($vals)) end -function _invlink_metadata!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns) +function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns # Construct the new transformed values, and keep track of their lengths. vals_new = map(vns) do vn # Return early if we're already in constrained space OR if we're not - # supposed to touch this `vn`, e.g. 
when `vn` does not belong to the current sampler. + # supposed to touch this `vn`, e.g. when `vn` does not belong to the current sampler. # HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check. if !istrans(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns) return metadata.vals[getrange(metadata, vn)] @@ -1385,14 +1549,31 @@ function _invlink_metadata!(::Model, varinfo::VarInfo, metadata::Metadata, targe ) end +function _invlink_metadata!!( + model::Model, varinfo::VarInfo, metadata::VarNamedVector, target_vns +) + vns = target_vns === nothing ? keys(metadata) : target_vns + for vn in vns + transform = gettransform(metadata, vn) + old_val = getindex_internal(metadata, vn) + new_val, logjac = with_logabsdet_jacobian(transform, old_val) + # TODO(mhauru) We are calling a !! function but ignoring the return value. + acclogp!!(varinfo, -logjac) + new_transform = from_vec_transform(new_val) + metadata = setindex_internal!!(metadata, tovec(new_val), vn, new_transform) + settrans!(metadata, false, vn) + end + return metadata +end + """ islinked(vi::VarInfo, spl::Union{Sampler, SampleFromPrior}) Check whether `vi` is in the transformed space for a particular sampler `spl`. -Turing's Hamiltonian samplers use the `link` and `invlink` functions from +Turing's Hamiltonian samplers use the `link` and `invlink` functions from [Bijectors.jl](https://github.com/TuringLang/Bijectors.jl) to map a constrained variable -(for example, one bounded to the space `[0, 1]`) from its constrained space to the set of +(for example, one bounded to the space `[0, 1]`) from its constrained space to the set of real numbers. `islinked` checks if the number is in the constrained space or the real space. """ function islinked(vi::UntypedVarInfo, spl::Union{Sampler,SampleFromPrior}) @@ -1406,7 +1587,7 @@ end @generated function _islinked(vi, vns::NamedTuple{names}) where {names} out = [] for f in names - push!(out, :(length(vns.$f) == 0 ? 
false : istrans(vi, vns.$f[1]))) + push!(out, :(isempty(vns.$f) ? false : istrans(vi, vns.$f[1]))) end return Expr(:||, false, out...) end @@ -1423,9 +1604,11 @@ function nested_setindex_maybe!( nothing end end -function _nested_setindex_maybe!(vi::VarInfo, md::Metadata, val, vn::VarName) +function _nested_setindex_maybe!( + vi::VarInfo, md::Union{Metadata,VarNamedVector}, val, vn::VarName +) # If `vn` is in `vns`, then we can just use the standard `setindex!`. - vns = md.vns + vns = Base.keys(md) if vn in vns setindex!(vi, val, vn) return vn @@ -1436,8 +1619,7 @@ function _nested_setindex_maybe!(vi::VarInfo, md::Metadata, val, vn::VarName) i === nothing && return nothing vn_parent = vns[i] - dist = getdist(md, vn_parent) - val_parent = getindex(vi, vn_parent, dist) # TODO: Ensure that we're working with a view here. + val_parent = getindex(vi, vn_parent) # TODO: Ensure that we're working with a view here. # Split the varname into its tail optic. optic = remove_parent_optic(vn_parent, vn) # Update the value for the parent. @@ -1448,7 +1630,10 @@ end # The default getindex & setindex!() for get & set values # NOTE: vi[vn] will always transform the variable to its original space and Julia type -getindex(vi::VarInfo, vn::VarName) = getindex(vi, vn, getdist(vi, vn)) +function getindex(vi::VarInfo, vn::VarName) + return from_maybe_linked_internal_transform(vi, vn)(getindex_internal(vi, vn)) +end + function getindex(vi::VarInfo, vn::VarName, dist::Distribution) @assert haskey(vi, vn) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" val = getindex_internal(vi, vn) @@ -1456,13 +1641,34 @@ function getindex(vi::VarInfo, vn::VarName, dist::Distribution) end function getindex(vi::VarInfo, vns::Vector{<:VarName}) - vals_linked = mapreduce(vcat, vns) do vn - getindex(vi, vn) + vals = map(vn -> getindex(vi, vn), vns) + + et = eltype(vals) + # This will catch type unstable cases, where vals has mixed types. 
+ if !isconcretetype(et) + throw(ArgumentError("All variables must have the same type.")) + end + + if et <: Vector + all_of_equal_dimension = all(x -> length(x) == length(vals[1]), vals) + if !all_of_equal_dimension + throw(ArgumentError("All variables must have the same dimension.")) + end + end + + # TODO(mhauru) I'm not very pleased with the return type varying like this, even though + # this should be type stable. + vec_vals = reduce(vcat, vals) + if et <: Vector + # The individual variables are multivariate, and thus we return the values as a + # matrix. + return reshape(vec_vals, (:, length(vns))) + else + # The individual variables are univariate, and thus we return a vector of scalars. + return vec_vals end - # HACK: I don't like this. - dist = getdist(vi, vns[1]) - return recombine(dist, vals_linked, length(vns)) end + function getindex(vi::VarInfo, vns::Vector{<:VarName}, dist::Distribution) @assert haskey(vi, vns[1]) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" vals_linked = mapreduce(vcat, vns) do vn @@ -1494,6 +1700,8 @@ end return expr end +# TODO(mhauru) I think the below implementation of setindex! is a mistake. It should be +# called setindex_internal! since it directly writes to the `vals` field of the metadata. """ setindex!(vi::VarInfo, val, vn::VarName) @@ -1551,15 +1759,15 @@ end return map(vn -> vi[vn], f_vns) end -haskey(metadata::Metadata, vn::VarName) = haskey(metadata.idcs, vn) +Base.haskey(metadata::Metadata, vn::VarName) = haskey(metadata.idcs, vn) """ haskey(vi::VarInfo, vn::VarName) Check whether `vn` has been sampled in `vi`. 
""" -haskey(vi::VarInfo, vn::VarName) = haskey(getmetadata(vi, vn), vn) -function haskey(vi::TypedVarInfo, vn::VarName) +Base.haskey(vi::VarInfo, vn::VarName) = haskey(getmetadata(vi, vn), vn) +function Base.haskey(vi::TypedVarInfo, vn::VarName) md_haskey = map(vi.metadata) do metadata haskey(metadata, vn) end @@ -1629,6 +1837,20 @@ function BangBang.push!!( return vi end +function Base.push!(vi::VectorVarInfo, vn::VarName, val, args...) + push!(getmetadata(vi, vn), vn, val, args...) + return vi +end + +function Base.push!(vi::VectorVarInfo, pair::Pair, args...) + vn, val = pair + return push!(vi, vn, val, args...) +end + +# TODO(mhauru) push! can't be implemented in-place for TypedVarInfo if the symbol doesn't +# exist in the TypedVarInfo already. We could implement it in the cases where it it does +# exist, but that feels a bit pointless. I think we should rather rely on `push!!`. + function Base.push!(meta::Metadata, vn, r, dist, gidset, num_produce) val = tovec(r) meta.idcs[vn] = length(meta.idcs) + 1 @@ -1646,6 +1868,11 @@ function Base.push!(meta::Metadata, vn, r, dist, gidset, num_produce) return meta end +function Base.delete!(vi::VarInfo, vn::VarName) + delete!(getmetadata(vi, vn), vn) + return vi +end + """ setorder!(vi::VarInfo, vn::VarName, index::Int) @@ -1660,6 +1887,7 @@ function setorder!(metadata::Metadata, vn::VarName, index::Int) metadata.orders[metadata.idcs[vn]] = index return metadata end +setorder!(vnv::VarNamedVector, ::VarName, ::Int) = vnv """ getorder(vi::VarInfo, vn::VarName) @@ -1685,21 +1913,45 @@ end function is_flagged(metadata::Metadata, vn::VarName, flag::String) return metadata.flags[flag][getidx(metadata, vn)] end +function is_flagged(::VarNamedVector, ::VarName, flag::String) + if flag == "del" + return true + else + throw(ErrorException("Flag $flag not valid for VarNamedVector")) + end +end +# TODO(mhauru) The "ignorable" argument is a temporary hack while developing VarNamedVector, +# but still having to support the interface 
based on Metadata too """ - unset_flag!(vi::VarInfo, vn::VarName, flag::String) + unset_flag!(vi::VarInfo, vn::VarName, flag::String, ignorable::Bool=false) Set `vn`'s value for `flag` to `false` in `vi`. + +Setting some flags for some `VarInfo` types is not possible, and by default attempting to do +so will error. If `ignorable` is set to `true` then this will silently be ignored instead. """ -function unset_flag!(vi::VarInfo, vn::VarName, flag::String) - unset_flag!(getmetadata(vi, vn), vn, flag) +function unset_flag!(vi::VarInfo, vn::VarName, flag::String, ignorable::Bool=false) + unset_flag!(getmetadata(vi, vn), vn, flag, ignorable) return vi end -function unset_flag!(metadata::Metadata, vn::VarName, flag::String) +function unset_flag!(metadata::Metadata, vn::VarName, flag::String, ignorable::Bool=false) metadata.flags[flag][getidx(metadata, vn)] = false return metadata end +function unset_flag!(vnv::VarNamedVector, ::VarName, flag::String, ignorable::Bool=false) + if ignorable + return vnv + end + if flag == "del" + throw(ErrorException("The \"del\" flag cannot be unset for VarNamedVector")) + else + throw(ErrorException("Flag $flag not valid for VarNamedVector")) + end + return vnv +end + """ set_retained_vns_del_by_spl!(vi::VarInfo, spl::Sampler) @@ -1804,7 +2056,7 @@ end ) where {names} updates = map(names) do n quote - for vn in metadata.$n.vns + for vn in Base.keys(metadata.$n) indices_found = kernel!(vi, vn, values, keys_strings) if indices_found !== nothing num_indices_seen += length(indices_found) @@ -1886,14 +2138,6 @@ julia> DynamicPPL.setval!(var_info, (m = 100.0, )); # set `m` and and keep `x[1] julia> var_info[@varname(m)] # [✓] changed 100.0 -julia> var_info[@varname(x[1])] # [✓] unchanged --0.22312984965118443 - -julia> m(rng, var_info); # rerun model - -julia> var_info[@varname(m)] # [✓] unchanged -100.0 - julia> var_info[@varname(x[1])] # [✓] unchanged -0.22312984965118443 ``` @@ -1923,9 +2167,9 @@ end Set the values in `vi` to the provided 
values and those which are not present in `x` or `chains` to *be* resampled. -Note that this does *not* resample the values not provided! It will call `setflag!(vi, vn, "del")` -for variables `vn` for which no values are provided, which means that the next time we call `model(vi)` these -variables will be resampled. +Note that this does *not* resample the values not provided! It will call +`setflag!(vi, vn, "del")` for variables `vn` for which no values are provided, which means +that the next time we call `model(vi)` these variables will be resampled. ## Note - This suffers from the same limitations as [`setval!`](@ref). See `setval!` for more info. @@ -1945,7 +2189,7 @@ julia> rng = StableRNG(42); julia> m = demo([missing]); -julia> var_info = DynamicPPL.VarInfo(rng, m); +julia> var_info = DynamicPPL.VarInfo(rng, m, SampleFromPrior(), DefaultContext(), DynamicPPL.Metadata()); # Checking the setting of "del" flags only makes sense for VarInfo{<:Metadata}. For VarInfo{<:VarNamedVector} the flag is effectively always set. julia> var_info[@varname(m)] -0.6702516921145671 @@ -2043,6 +2287,9 @@ function values_as( return ConstructionBase.constructorof(D)(iter) end +values_as(vi::VectorVarInfo, args...) = values_as(vi.metadata, args...) +values_as(vi::VectorVarInfo, T::Type{Vector}) = values_as(vi.metadata, T) + function values_from_metadata(md::Metadata) return ( # `copy` to avoid accidentally mutation of internal representation. @@ -2052,6 +2299,8 @@ function values_from_metadata(md::Metadata) ) end +values_from_metadata(md::VarNamedVector) = pairs(md) + # Transforming from internal representation to distribution representation. # Without `dist` argument: base on `dist` extracted from self. 
function from_internal_transform(vi::VarInfo, vn::VarName) @@ -2060,11 +2309,17 @@ end function from_internal_transform(md::Metadata, vn::VarName) return from_internal_transform(md, vn, getdist(md, vn)) end +function from_internal_transform(md::VarNamedVector, vn::VarName) + return gettransform(md, vn) +end # With both `vn` and `dist` arguments: base on provided `dist`. function from_internal_transform(vi::VarInfo, vn::VarName, dist) return from_internal_transform(getmetadata(vi, vn), vn, dist) end from_internal_transform(::Metadata, ::VarName, dist) = from_vec_transform(dist) +function from_internal_transform(::VarNamedVector, ::VarName, dist) + return from_vec_transform(dist) +end # Without `dist` argument: base on `dist` extracted from self. function from_linked_internal_transform(vi::VarInfo, vn::VarName) @@ -2073,6 +2328,9 @@ end function from_linked_internal_transform(md::Metadata, vn::VarName) return from_linked_internal_transform(md, vn, getdist(md, vn)) end +function from_linked_internal_transform(md::VarNamedVector, vn::VarName) + return gettransform(md, vn) +end # With both `vn` and `dist` arguments: base on provided `dist`. function from_linked_internal_transform(vi::VarInfo, vn::VarName, dist) # Dispatch to metadata in case this alters the behavior. @@ -2081,3 +2339,6 @@ end function from_linked_internal_transform(::Metadata, ::VarName, dist) return from_linked_vec_transform(dist) end +function from_linked_internal_transform(::VarNamedVector, ::VarName, dist) + return from_linked_vec_transform(dist) +end diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl new file mode 100644 index 000000000..a5097602d --- /dev/null +++ b/src/varnamedvector.jl @@ -0,0 +1,1561 @@ +""" + VarNamedVector + +A container that stores values in a vectorised form, but indexable by variable names. + +A `VarNamedVector` can be thought of as an ordered mapping from `VarName`s to pairs of +`(internal_value, transform)`. 
Here `internal_value` is a vectorised value for the variable +and `transform` is a function such that `transform(internal_value)` is the "original" value +of the variable, the one that the user sees. For instance, if the variable has a matrix +value, `internal_value` could be a flattened `Vector` of its elements, and `transform` would +be a `reshape` call. + +`transform` may simply implement vectorisation, but it may do more. Most importantly, it may +implement linking, where the internal storage of a random variable is in a form where all +values in Euclidean space are valid. This is useful for sampling, because the sampler can +make changes to `internal_value` without worrying about constraints on the space of +the random variable. + +The way to access this storage format directly is through the functions `getindex_internal` +and `setindex_internal!`. The `transform` argument for `setindex_internal!` is optional; by +default it is either the identity, or the existing transform if a value already exists for +this `VarName`. + +`VarNamedVector` also provides a `Dict`-like interface that hides away the internal +vectorisation. This can be accessed with `getindex` and `setindex!`. `setindex!` only takes +the value; the transform is automatically set to be a simple vectorisation. The only notable +deviation from the behavior of a `Dict` is that `setindex!` will throw an error if one tries +to set a new value for a variable that lives in a different "space" than the old one (e.g. +is of a different type or size). This is because `setindex!` does not change the transform +of a variable, e.g. it preserves linking, and thus the new value must be compatible with the old +transform. + +For now, a third value is in fact stored for each `VarName`: a boolean indicating whether +the variable has been transformed to unconstrained Euclidean space or not. This is only in +place temporarily due to the needs of our old Gibbs sampler. 
+ +Internally, `VarNamedVector` stores the values of all variables in a single contiguous +vector. This makes some operations more efficient, and means that one can access the entire +contents of the internal storage quickly with `getindex_internal(vnv, :)`. The other fields +of `VarNamedVector` are mostly used to keep track of which part of the internal storage +belongs to which `VarName`. + +# Fields + +$(FIELDS) + +# Extended help + +The values for different variables are internally all stored in a single vector. For +instance, +```jldoctest varnamedvector-struct +julia> using DynamicPPL: ReshapeTransform, VarNamedVector, @varname, setindex!, update!, getindex_internal + +julia> vnv = VarNamedVector(); + +julia> setindex!(vnv, [0.0, 0.0, 0.0, 0.0], @varname(x)); + +julia> setindex!(vnv, reshape(1:6, (2,3)), @varname(y)); + +julia> vnv.vals +10-element Vector{Real}: + 0.0 + 0.0 + 0.0 + 0.0 + 1 + 2 + 3 + 4 + 5 + 6 +``` + +The `varnames`, `ranges`, and `varname_to_index` fields keep track of which value belongs to +which variable. The `transforms` field stores the transformations that are needed to transform +the vectorised internal storage back to its original form: + +```jldoctest varnamedvector-struct +julia> vnv.transforms[vnv.varname_to_index[@varname(y)]] == DynamicPPL.ReshapeTransform((6,), (2,3)) +true +``` + +If a variable is updated with a new value that is of a smaller dimension than the old +value, rather than resizing `vnv.vals`, some elements in `vnv.vals` are marked as inactive. + +```jldoctest varnamedvector-struct +julia> update!(vnv, [46.0, 48.0], @varname(x)) + +julia> vnv.vals +10-element Vector{Real}: + 46.0 + 48.0 + 0.0 + 0.0 + 1 + 2 + 3 + 4 + 5 + 6 + +julia> println(vnv.num_inactive); +OrderedDict(1 => 2) +``` + +This helps avoid unnecessary memory allocations for values that repeatedly change dimension. 
+The user does not have to worry about the inactive entries as long as they use functions +like `setindex!` and `getindex` rather than directly accessing `vnv.vals`. + +```jldoctest varnamedvector-struct +julia> vnv[@varname(x)] +2-element Vector{Float64}: + 46.0 + 48.0 + +julia> getindex_internal(vnv, :) +8-element Vector{Real}: + 46.0 + 48.0 + 1 + 2 + 3 + 4 + 5 + 6 +``` +""" +struct VarNamedVector{ + K<:VarName,V,TVN<:AbstractVector{K},TVal<:AbstractVector{V},TTrans<:AbstractVector +} + """ + mapping from a `VarName` to its integer index in `varnames`, `ranges` and `transforms` + """ + varname_to_index::OrderedDict{K,Int} + + """ + vector of `VarName`s for the variables, where `varnames[varname_to_index[vn]] == vn` + """ + varnames::TVN # AbstractVector{<:VarName} + + """ + vector of index ranges in `vals` corresponding to `varnames`; each `VarName` `vn` has + a single index or a set of contiguous indices, such that the values of `vn` can be found + at `vals[ranges[varname_to_index[vn]]]` + """ + ranges::Vector{UnitRange{Int}} + + """ + vector of values of all variables; the value(s) of `vn` is/are + `vals[ranges[varname_to_index[vn]]]` + """ + vals::TVal # AbstractVector{<:Real} + + """ + vector of transformations, so that `transforms[varname_to_index[vn]]` is a callable + that transforms the value of `vn` back to its original space, undoing any linking and + vectorisation + """ + transforms::TTrans + + """ + vector of booleans indicating whether a variable has been transformed to unconstrained + Euclidean space or not, i.e. whether its domain is all of `ℝⁿ`. Having + `is_unconstrained[varname_to_index[vn]] == false` does not necessarily mean that a + variable is constrained, but rather that it's not guaranteed to not be. + """ + is_unconstrained::BitVector + + """ + mapping from a variable index to the number of inactive entries for that variable. + Inactive entries are elements in `vals` that are not part of the value of any variable. 
+ They arise when a variable is set to a new value with a different dimension, in-place. + Inactive entries always come after the last active entry for the given variable. + See the extended help with `??VarNamedVector` for more details. + """ + num_inactive::OrderedDict{Int,Int} + + function VarNamedVector( + varname_to_index, + varnames::TVN, + ranges, + vals::TVal, + transforms::TTrans, + is_unconstrained=fill!(BitVector(undef, length(varnames)), 0), + num_inactive=OrderedDict{Int,Int}(), + ) where {K,V,TVN<:AbstractVector{K},TVal<:AbstractVector{V},TTrans<:AbstractVector} + if length(varnames) != length(ranges) || + length(varnames) != length(transforms) || + length(varnames) != length(is_unconstrained) || + length(varnames) != length(varname_to_index) + msg = ( + "Inputs to VarNamedVector have inconsistent lengths. Got lengths varnames: " * + "$(length(varnames)), ranges: " * + "$(length(ranges)), " * + "transforms: $(length(transforms)), " * + "is_unconstrained: $(length(is_unconstrained)), " * + "varname_to_index: $(length(varname_to_index))." + ) + throw(ArgumentError(msg)) + end + + num_vals = mapreduce(length, (+), ranges; init=0) + sum(values(num_inactive)) + if num_vals != length(vals) + msg = ( + "The total number of elements in `vals` ($(length(vals))) does not match " * + "the sum of the lengths of the ranges and the number of inactive entries " * + "($(num_vals))." + ) + throw(ArgumentError(msg)) + end + + if Set(values(varname_to_index)) != Set(axes(varnames, 1)) + msg = ( + "The set of values of `varname_to_index` is not the set of valid indices " * + "for `varnames`." + ) + throw(ArgumentError(msg)) + end + + if !issubset(Set(keys(num_inactive)), Set(values(varname_to_index))) + msg = ( + "The keys of `num_inactive` are not a subset of the values of " * + "`varname_to_index`." + ) + throw(ArgumentError(msg)) + end + + # Check that the varnames don't overlap. The time cost is quadratic in number of + # variables. 
If this ever becomes an issue, we should be able to go down to at least + # N log N by sorting based on subsumes-order. + for vn1 in keys(varname_to_index) + for vn2 in keys(varname_to_index) + vn1 === vn2 && continue + if subsumes(vn1, vn2) + msg = ( + "Variables in a VarNamedVector should not subsume each other, " * + "but $vn1 subsumes $vn2, i.e. $vn2 describes a subrange of $vn1." + ) + throw(ArgumentError(msg)) + end + end + end + + # We could also have a test to check that the ranges don't overlap, but that sounds + # unlikely to occur, and implementing it in linear time would require a tiny bit of + # thought. + + return new{K,V,TVN,TVal,TTrans}( + varname_to_index, + varnames, + ranges, + vals, + transforms, + is_unconstrained, + num_inactive, + ) + end +end + +function VarNamedVector{K,V}() where {K,V} + return VarNamedVector(OrderedDict{K,Int}(), K[], UnitRange{Int}[], V[], Any[]) +end + +# TODO(mhauru) I would like for this to be VarNamedVector(Union{}, Union{}). Similarly the +# transform vector type above could then be Union{}[]. This would allow expanding the +# VarName and element types only as necessary, which would help keep them concrete. However, +# making that change here opens some other cans of worms related to how VarInfo uses +# BangBang, that I don't want to deal with right now. +VarNamedVector() = VarNamedVector{VarName,Real}() +VarNamedVector(xs::Pair...) = VarNamedVector(OrderedDict(xs...)) +VarNamedVector(x::AbstractDict) = VarNamedVector(keys(x), values(x)) +function VarNamedVector(varnames, vals) + return VarNamedVector(collect_maybe(varnames), collect_maybe(vals)) +end +function VarNamedVector( + varnames::AbstractVector, + orig_vals::AbstractVector, + transforms=fill(identity, length(varnames)), +) + # Convert `vals` into a vector of vectors. + vals_vecs = map(tovec, orig_vals) + transforms = map( + (t, val) -> _compose_no_identity(t, from_vec_transform(val)), transforms, orig_vals + ) + # TODO: Is this really the way to do this? 
+ if !(eltype(varnames) <: VarName) + varnames = convert(Vector{VarName}, varnames) + end + varname_to_index = OrderedDict{eltype(varnames),Int}( + vn => i for (i, vn) in enumerate(varnames) + ) + vals = reduce(vcat, vals_vecs) + # Make the ranges. + ranges = Vector{UnitRange{Int}}() + offset = 0 + for x in vals_vecs + r = (offset + 1):(offset + length(x)) + push!(ranges, r) + offset = r[end] + end + + return VarNamedVector(varname_to_index, varnames, ranges, vals, transforms) +end + +function ==(vnv_left::VarNamedVector, vnv_right::VarNamedVector) + return vnv_left.varname_to_index == vnv_right.varname_to_index && + vnv_left.varnames == vnv_right.varnames && + vnv_left.ranges == vnv_right.ranges && + vnv_left.vals == vnv_right.vals && + vnv_left.transforms == vnv_right.transforms && + vnv_left.is_unconstrained == vnv_right.is_unconstrained && + vnv_left.num_inactive == vnv_right.num_inactive +end + +getidx(vnv::VarNamedVector, vn::VarName) = vnv.varname_to_index[vn] + +getrange(vnv::VarNamedVector, idx::Int) = vnv.ranges[idx] +getrange(vnv::VarNamedVector, vn::VarName) = getrange(vnv, getidx(vnv, vn)) + +gettransform(vnv::VarNamedVector, idx::Int) = vnv.transforms[idx] +gettransform(vnv::VarNamedVector, vn::VarName) = gettransform(vnv, getidx(vnv, vn)) + +# TODO(mhauru) Eventually I would like to rename the istrans function to is_unconstrained, +# but that's significantly breaking. +""" + istrans(vnv::VarNamedVector, vn::VarName) + +Return a boolean for whether `vn` is guaranteed to have been transformed so that its domain +is all of Euclidean space. +""" +istrans(vnv::VarNamedVector, vn::VarName) = vnv.is_unconstrained[getidx(vnv, vn)] + +""" + settrans!(vnv::VarNamedVector, val::Bool, vn::VarName) + +Set the value for whether `vn` is guaranteed to have been transformed so that all of +Euclidean space is its domain. 
+""" +function settrans!(vnv::VarNamedVector, val::Bool, vn::VarName) + return vnv.is_unconstrained[vnv.varname_to_index[vn]] = val +end + +function settrans!!(vnv::VarNamedVector, val::Bool, vn::VarName) + settrans!(vnv, val, vn) + return vnv +end + +""" + has_inactive(vnv::VarNamedVector) + +Returns `true` if `vnv` has inactive entries. + +See also: [`num_inactive`](@ref) +""" +has_inactive(vnv::VarNamedVector) = !isempty(vnv.num_inactive) + +""" + num_inactive(vnv::VarNamedVector) + +Return the number of inactive entries in `vnv`. + +See also: [`has_inactive`](@ref), [`num_allocated`](@ref) +""" +num_inactive(vnv::VarNamedVector) = sum(values(vnv.num_inactive)) + +""" + num_inactive(vnv::VarNamedVector, vn::VarName) + +Returns the number of inactive entries for `vn` in `vnv`. +""" +num_inactive(vnv::VarNamedVector, vn::VarName) = num_inactive(vnv, getidx(vnv, vn)) +num_inactive(vnv::VarNamedVector, idx::Int) = get(vnv.num_inactive, idx, 0) + +""" + num_allocated(vnv::VarNamedVector) + num_allocated(vnv::VarNamedVector[, vn::VarName]) + num_allocated(vnv::VarNamedVector[, idx::Int]) + +Return the number of allocated entries in `vnv`, both active and inactive. + +If either a `VarName` or an `Int` index is specified, only count entries allocated for that +variable. + +Allocated entries take up memory in `vnv.vals`, but, if inactive, may not currently hold any +meaningful data. One can remove them with [`contiguify!`](@ref), but doing so may cause more +memory allocations in the future if variables change dimension. +""" +num_allocated(vnv::VarNamedVector) = length(vnv.vals) +num_allocated(vnv::VarNamedVector, vn::VarName) = num_allocated(vnv, getidx(vnv, vn)) +function num_allocated(vnv::VarNamedVector, idx::Int) + return length(getrange(vnv, idx)) + num_inactive(vnv, idx) +end + +# Dictionary interface. 
+Base.isempty(vnv::VarNamedVector) = isempty(vnv.varnames) +Base.length(vnv::VarNamedVector) = length(vnv.varnames) +Base.keys(vnv::VarNamedVector) = vnv.varnames +Base.values(vnv::VarNamedVector) = Iterators.map(Base.Fix1(getindex, vnv), vnv.varnames) +Base.pairs(vnv::VarNamedVector) = (vn => vnv[vn] for vn in keys(vnv)) +Base.haskey(vnv::VarNamedVector, vn::VarName) = haskey(vnv.varname_to_index, vn) + +# Vector-like interface. +Base.eltype(vnv::VarNamedVector) = eltype(vnv.vals) + +""" + length_internal(vnv::VarNamedVector) + +Return the length of the internal storage vector of `vnv`, ignoring inactive entries. +""" +function length_internal(vnv::VarNamedVector) + if !has_inactive(vnv) + return length(vnv.vals) + else + return sum(length, vnv.ranges) + end +end + +# Getting and setting values + +function Base.getindex(vnv::VarNamedVector, vn::VarName) + x = getindex_internal(vnv, vn) + f = gettransform(vnv, vn) + return f(x) +end + +""" + find_containing_range(ranges::AbstractVector{<:AbstractRange}, x) + +Find the first range in `ranges` that contains `x`. + +Throw an `ArgumentError` if `x` is not in any of the ranges. +""" +function find_containing_range(ranges::AbstractVector{<:AbstractRange}, x) + # TODO: Assume `ranges` to be sorted and contiguous, and use `searchsortedfirst` + # for a more efficient approach. + range_idx = findfirst(Base.Fix1(∈, x), ranges) + + # If we're out of bounds, we raise an error. + if range_idx === nothing + throw(ArgumentError("Value $x is not in any of the ranges.")) + end + + return range_idx +end + +""" + adjusted_ranges(vnv::VarNamedVector) + +Return what `vnv.ranges` would be if there were no inactive entries. +""" +function adjusted_ranges(vnv::VarNamedVector) + # Every range following inactive entries needs to be shifted. + offset = 0 + ranges_adj = similar(vnv.ranges) + for (idx, r) in enumerate(vnv.ranges) + # Remove the `offset` in `r` due to inactive entries. + ranges_adj[idx] = r .- offset + # Update `offset`. 
+ offset += get(vnv.num_inactive, idx, 0) + end + + return ranges_adj +end + +""" + index_to_vals_index(vnv::VarNamedVector, i::Int) + +Convert an integer index that ignores inactive entries to an index that accounts for them. + +This is needed when the user wants to index `vnv` like a vector, but shouldn't have to care +about inactive entries in `vnv.vals`. +""" +function index_to_vals_index(vnv::VarNamedVector, i::Int) + # If we don't have any inactive entries, there's nothing to do. + has_inactive(vnv) || return i + + # Get the adjusted ranges. + ranges_adj = adjusted_ranges(vnv) + # Determine the adjusted range that the index corresponds to. + r_idx = find_containing_range(ranges_adj, i) + r = vnv.ranges[r_idx] + # Determine how much of the index `i` is used to get to this range. + i_used = r_idx == 1 ? 0 : sum(length, ranges_adj[1:(r_idx - 1)]) + # Use remainder to index into `r`. + i_remainder = i - i_used + return r[i_remainder] +end + +""" + getindex_internal(vnv::VarNamedVector, vn::VarName) + +Like `getindex`, but returns the values as they are stored in `vnv`, without transforming. +""" +getindex_internal(vnv::VarNamedVector, vn::VarName) = vnv.vals[getrange(vnv, vn)] + +""" + getindex_internal(vnv::VarNamedVector, i::Int) + +Gets the `i`th element of the internal storage vector, ignoring inactive entries. +""" +getindex_internal(vnv::VarNamedVector, i::Int) = vnv.vals[index_to_vals_index(vnv, i)] + +function getindex_internal(vnv::VarNamedVector, ::Colon) + return if has_inactive(vnv) + mapreduce(Base.Fix1(getindex, vnv.vals), vcat, vnv.ranges) + else + vnv.vals + end +end + +# TODO(mhauru): Remove this as soon as possible. Only needed because of the old Gibbs +# sampler. 
+function Base.getindex(vnv::VarNamedVector, spl::AbstractSampler) + throw(ErrorException("Cannot index a VarNamedVector with a sampler.")) +end + +function Base.setindex!(vnv::VarNamedVector, val, vn::VarName) + if haskey(vnv, vn) + return update!(vnv, val, vn) + else + return insert!(vnv, val, vn) + end +end + +""" + reset!(vnv::VarNamedVector, val, vn::VarName) + +Reset the value of `vn` in `vnv` to `val`. + +This differs from `setindex!` in that it will always change the transform of the variable +to be the default vectorisation transform. This undoes any possible linking. + +# Examples + +```jldoctest varnamedvector-reset +julia> using DynamicPPL: VarNamedVector, @varname, reset! + +julia> vnv = VarNamedVector(); + +julia> vnv[@varname(x)] = reshape(1:9, (3, 3)); + +julia> setindex!(vnv, 2.0, @varname(x)) +ERROR: An error occurred while assigning the value 2.0 to variable x. If you are changing the type or size of a variable you'll need to call reset! +[...] + +julia> reset!(vnv, 2.0, @varname(x)); + +julia> vnv[@varname(x)] +2.0 +``` +""" +function reset!(vnv::VarNamedVector, val, vn::VarName) + f = from_vec_transform(val) + retval = setindex_internal!(vnv, tovec(val), vn, f) + settrans!(vnv, false, vn) + return retval +end + +""" + update!(vnv::VarNamedVector, val, vn::VarName) + +Update the value of `vn` in `vnv` to `val`. + +Like `setindex!`, but errors if the key `vn` doesn't exist. +""" +function update!(vnv::VarNamedVector, val, vn::VarName) + if !haskey(vnv, vn) + throw(KeyError(vn)) + end + f = inverse(gettransform(vnv, vn)) + internal_val = try + f(val) + catch + error( + "An error occurred while assigning the value $val to variable $vn. " * + "If you are changing the type or size of a variable you'll need to call " * + "reset!", + ) + end + return setindex_internal!(vnv, internal_val, vn) +end + +""" + insert!(vnv::VarNamedVector, val, vn::VarName) + +Add a variable with given value to `vnv`. 
+ +Like `setindex!`, but errors if the key `vn` already exists. +""" +function Base.insert!(vnv::VarNamedVector, val, vn::VarName) + if haskey(vnv, vn) + throw("Variable $vn already exists in VarNamedVector.") + end + return reset!(vnv, val, vn) +end + +""" + push!(vnv::VarNamedVector, pair::Pair) + +Add a variable with given value to `vnv`. Pair should be a `VarName` and a value. +""" +function Base.push!(vnv::VarNamedVector, pair::Pair) + vn, val = pair + # TODO(mhauru) Or should this rather call `reset!`? It would be more inline with what + # Dict does, but could also cause confusion. + return setindex!(vnv, val, vn) +end + +""" + setindex_internal!(vnv::VarNamedVector, val, i::Int) + +Sets the `i`th element of the internal storage vector, ignoring inactive entries. +""" +function setindex_internal!(vnv::VarNamedVector, val, i::Int) + return vnv.vals[index_to_vals_index(vnv, i)] = val +end + +""" + setindex_internal!(vnv::VarNamedVector, val, vn::VarName[, transform]) + +Like `setindex!`, but sets the values as they are stored internally in `vnv`. + +Optionally can set the transformation, such that `transform(val)` is the original value of +the variable. By default, the transform is the identity if creating a new entry in `vnv`, or +the existing transform if updating an existing entry. +""" +function setindex_internal!( + vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing +) + if haskey(vnv, vn) + return update_internal!(vnv, val, vn, transform) + else + return insert_internal!(vnv, val, vn, transform) + end +end + +""" + insert_internal!(vnv::VarNamedVector, val::AbstractVector, vn::VarName[, transform]) + +Add a variable with given value to `vnv`. + +Like `setindex_internal!`, but errors if the key `vn` already exists. + +`transform` should be a function that converts `val` to the original representation. By +default it's `identity`. 
+""" +function insert_internal!( + vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing +) + if transform === nothing + transform = identity + end + haskey(vnv, vn) && throw(ArgumentError("variable name $vn already exists")) + # NOTE: We need to compute the `nextrange` BEFORE we start mutating the underlying + # storage. + r_new = nextrange(vnv, val) + vnv.varname_to_index[vn] = length(vnv.varname_to_index) + 1 + push!(vnv.varnames, vn) + push!(vnv.ranges, r_new) + append!(vnv.vals, val) + push!(vnv.transforms, transform) + push!(vnv.is_unconstrained, false) + return nothing +end + +""" + update_internal!(vnv::VarNamedVector, val::AbstractVector, vn::VarName[, transform]) + +Update an existing entry for `vn` in `vnv` with the value `val`. + +Like `setindex_internal!`, but errors if the key `vn` doesn't exist. + +`transform` should be a function that converts `val` to the original representation. By +default it's the same as the old transform for `vn`. +""" +function update_internal!( + vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing +) + # Here we update an existing entry. + if !haskey(vnv, vn) + throw(KeyError(vn)) + end + idx = getidx(vnv, vn) + # Extract the old range. + r_old = getrange(vnv, idx) + start_old, end_old = first(r_old), last(r_old) + n_old = length(r_old) + # Compute the new range. + n_new = length(val) + start_new = start_old + end_new = start_old + n_new - 1 + r_new = start_new:end_new + + #= + Suppose we currently have the following: + + | x | x | o | o | o | y | y | y | <- Current entries + + where 'o' denotes an inactive entry, and we're going to + update the variable `x` to be of size `k` instead of 2. + + We then have a few different scenarios: + 1. `k > 5`: All inactive entries become active + need to shift `y` to the right. + E.g. if `k = 7`, then + + | x | x | o | o | o | y | y | y | <- Current entries + | x | x | x | x | x | x | x | y | y | y | <- New entries + + 2. 
`k = 5`: All inactive entries become active. + Then + + | x | x | o | o | o | y | y | y | <- Current entries + | x | x | x | x | x | y | y | y | <- New entries + + 3. `k < 5`: Some inactive entries become active, some remain inactive. + E.g. if `k = 3`, then + + | x | x | o | o | o | y | y | y | <- Current entries + | x | x | x | o | o | y | y | y | <- New entries + + 4. `k = 2`: No inactive entries become active. + Then + + | x | x | o | o | o | y | y | y | <- Current entries + | x | x | o | o | o | y | y | y | <- New entries + + 5. `k < 2`: More entries become inactive. + E.g. if `k = 1`, then + + | x | x | o | o | o | y | y | y | <- Current entries + | x | o | o | o | o | y | y | y | <- New entries + =# + + # Compute the allocated space for `vn`. + had_inactive = haskey(vnv.num_inactive, idx) + n_allocated = had_inactive ? n_old + vnv.num_inactive[idx] : n_old + + if n_new > n_allocated + # Then we need to grow the underlying vector. + n_extra = n_new - n_allocated + # Allocate. + resize!(vnv.vals, length(vnv.vals) + n_extra) + # Shift current values. + shift_right!(vnv.vals, end_old + 1, n_extra) + # No more inactive entries. + had_inactive && delete!(vnv.num_inactive, idx) + # Update the ranges for all variables after this one. + shift_subsequent_ranges_by!(vnv, idx, n_extra) + elseif n_new == n_allocated + # => No more inactive entries. + had_inactive && delete!(vnv.num_inactive, idx) + else + # `n_new < n_allocated` + # => Need to update the number of inactive entries. + vnv.num_inactive[idx] = n_allocated - n_new + end + + # Update the range for this variable. + vnv.ranges[idx] = r_new + # Update the value. + vnv.vals[r_new] = val + if transform !== nothing + # Update the transform. + vnv.transforms[idx] = transform + end + + # TODO: Should we maybe sweep over inactive ranges and re-contiguify + # if the total number of inactive elements is "large" in some sense? 
+ + return nothing +end + +# TODO(mhauru) The gidset and num_produce arguments are used by the old Gibbs sampler. +# Remove this method as soon as possible. +function BangBang.push!(vnv::VarNamedVector, vn, val, dist, gidset, num_produce) + f = from_vec_transform(dist) + return setindex_internal!(vnv, tovec(val), vn, f) +end + +# BangBang versions of the above functions. +# The only difference is that update_internal!! and insert_internal!! check whether the +# container types of the VarNamedVector vector need to be expanded to accommodate the new +# values. If so, they create a new instance, otherwise they mutate in place. All the others +# functions, e.g. setindex!!, setindex_internal!!, etc., are carbon copies of the ! versions +# with every ! call replaced with a !! call. + +""" + loosen_types!!(vnv::VarNamedVector{K,V,TVN,TVal,TTrans}, ::Type{KNew}, ::Type{TransNew}) + +Loosen the types of `vnv` to allow varname type `KNew` and transformation type `TransNew`. + +If `KNew` is a subtype of `K` and `TransNew` is a subtype of the element type of the +`TTrans` then this is a no-op and `vnv` is returned as is. Otherwise a new `VarNamedVector` +is returned with the same data but more abstract types, so that variables of type `KNew` and +transformations of type `TransNew` can be pushed to it. Some of the underlying storage is +shared between `vnv` and the return value, and thus mutating one may affect the other. + +# See also +[`tighten_types`](@ref) + +# Examples + +```jldoctest varnamedvector-loosen-types +julia> using DynamicPPL: VarNamedVector, @varname, loosen_types!!, setindex_internal! + +julia> vnv = VarNamedVector(@varname(x) => [1.0]); + +julia> y_trans(x) = reshape(x, (2, 2)); + +julia> setindex_internal!(vnv, collect(1:4), @varname(y), y_trans) +ERROR: MethodError: Cannot `convert` an object of type +[...] 
+ +julia> vnv_loose = DynamicPPL.loosen_types!!(vnv, typeof(@varname(y)), typeof(y_trans)); + +julia> setindex_internal!(vnv_loose, collect(1:4), @varname(y), y_trans) + +julia> vnv_loose[@varname(y)] +2×2 Matrix{Float64}: + 1.0 3.0 + 2.0 4.0 +``` +""" +function loosen_types!!( + vnv::VarNamedVector, ::Type{KNew}, ::Type{TransNew} +) where {KNew,TransNew} + K = eltype(vnv.varnames) + Trans = eltype(vnv.transforms) + if KNew <: K && TransNew <: Trans + return vnv + else + vn_type = promote_type(K, KNew) + transform_type = promote_type(Trans, TransNew) + return VarNamedVector( + OrderedDict{vn_type,Int}(vnv.varname_to_index), + Vector{vn_type}(vnv.varnames), + vnv.ranges, + vnv.vals, + Vector{transform_type}(vnv.transforms), + vnv.is_unconstrained, + vnv.num_inactive, + ) + end +end + +""" + tighten_types(vnv::VarNamedVector) + +Return a copy of `vnv` with the most concrete types possible. + +For instance, if `vnv` has its vector of transforms have eltype `Any`, but all the +transforms are actually identity transformations, this function will return a new +`VarNamedVector` with the transforms vector having eltype `typeof(identity)`. + +This is a lot like the reverse of [`loosen_types!!`](@ref), but with two notable +differences: Unlike `loosen_types!!`, this function does not mutate `vnv`; it also changes +not only the key and transform eltypes, but also the values eltype. + +# See also +[`loosen_types!!`](@ref) + +# Examples + +```jldoctest varnamedvector-tighten-types +julia> using DynamicPPL: VarNamedVector, @varname, loosen_types!!, setindex_internal! 
+ +julia> vnv = VarNamedVector(); + +julia> setindex!(vnv, [23], @varname(x)) + +julia> eltype(vnv) +Real + +julia> vnv.transforms +1-element Vector{Any}: + identity (generic function with 1 method) + +julia> vnv_tight = DynamicPPL.tighten_types(vnv); + +julia> eltype(vnv_tight) == Int +true + +julia> vnv_tight.transforms +1-element Vector{typeof(identity)}: + identity (generic function with 1 method) +``` +""" +function tighten_types(vnv::VarNamedVector) + return VarNamedVector( + OrderedDict(vnv.varname_to_index...), + map(identity, vnv.varnames), + copy(vnv.ranges), + map(identity, vnv.vals), + map(identity, vnv.transforms), + copy(vnv.is_unconstrained), + copy(vnv.num_inactive), + ) +end + +function BangBang.setindex!!(vnv::VarNamedVector, val, vn::VarName) + if haskey(vnv, vn) + return update!!(vnv, val, vn) + else + return insert!!(vnv, val, vn) + end +end + +function reset!!(vnv::VarNamedVector, val, vn::VarName) + f = from_vec_transform(val) + vnv = setindex_internal!!(vnv, tovec(val), vn, f) + vnv = settrans!!(vnv, false, vn) + return vnv +end + +function update!!(vnv::VarNamedVector, val, vn::VarName) + if !haskey(vnv, vn) + throw(KeyError(vn)) + end + f = inverse(gettransform(vnv, vn)) + internal_val = try + f(val) + catch + error( + "An error occurred while assigning the value $val to variable $vn. 
" * + "If you are changing the type or size of a variable you'll need to either " * + "`delete!` it first or use `setindex_internal!`", + ) + end + return setindex_internal!!(vnv, internal_val, vn) +end + +function insert!!(vnv::VarNamedVector, val, vn::VarName) + if haskey(vnv, vn) + throw("Variable $vn already exists in VarNamedVector.") + end + return reset!!(vnv, val, vn) +end + +function setindex_internal!!( + vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing +) + if haskey(vnv, vn) + return update_internal!!(vnv, val, vn, transform) + else + return insert_internal!!(vnv, val, vn, transform) + end +end + +function insert_internal!!(vnv::VarNamedVector, val, vn::VarName, transform=nothing) + if transform === nothing + transform = identity + end + vnv = loosen_types!!(vnv, typeof(vn), typeof(transform)) + insert_internal!(vnv, val, vn, transform) + return vnv +end + +function update_internal!!(vnv::VarNamedVector, val, vn::VarName, transform=nothing) + transform_resolved = transform === nothing ? gettransform(vnv, vn) : transform + vnv = loosen_types!!(vnv, typeof(vn), typeof(transform_resolved)) + update_internal!(vnv, val, vn, transform) + return vnv +end + +function BangBang.push!!(vnv::VarNamedVector, pair::Pair) + vn, val = pair + return setindex!!(vnv, val, vn) +end + +# TODO(mhauru) The gidset and num_produce arguments are used by the old Gibbs sampler. +# Remove this method as soon as possible. +function BangBang.push!!(vnv::VarNamedVector, vn, val, dist, gidset, num_produce) + f = from_vec_transform(dist) + return setindex_internal!!(vnv, tovec(val), vn, f) +end + +function Base.empty!(vnv::VarNamedVector) + # TODO: Or should the semantics be different, e.g. keeping `varnames`? 
+ empty!(vnv.varname_to_index) + empty!(vnv.varnames) + empty!(vnv.ranges) + empty!(vnv.vals) + empty!(vnv.transforms) + empty!(vnv.is_unconstrained) + empty!(vnv.num_inactive) + return nothing +end +BangBang.empty!!(vnv::VarNamedVector) = (empty!(vnv); return vnv) + +""" + replace_raw_storage(vnv::VarNamedVector, vals::AbstractVector) + +Replace the values in `vnv` with `vals`, as they are stored internally. + +This is useful when we want to update the entire underlying vector of values in one go or if +we want to change how the values are stored, e.g. alter the `eltype`. + +!!! warning + This replaces the raw underlying values, and so care should be taken when using this + function. For example, if `vnv` has any inactive entries, then the provided `vals` + should also contain the inactive entries to avoid unexpected behavior. + +# Examples + +```jldoctest varnamedvector-replace-raw-storage +julia> using DynamicPPL: VarNamedVector, replace_raw_storage + +julia> vnv = VarNamedVector(@varname(x) => [1.0]); + +julia> replace_raw_storage(vnv, [2.0])[@varname(x)] == [2.0] +true +``` + +This is also useful when we want to differentiate wrt. the values using automatic +differentiation, e.g. ForwardDiff.jl. + +```jldoctest varnamedvector-replace-raw-storage +julia> using ForwardDiff: ForwardDiff + +julia> f(x) = sum(abs2, replace_raw_storage(vnv, x)[@varname(x)]) +f (generic function with 1 method) + +julia> ForwardDiff.gradient(f, [1.0]) +1-element Vector{Float64}: + 2.0 +``` +""" +replace_raw_storage(vnv::VarNamedVector, vals) = Accessors.@set vnv.vals = vals + +# TODO(mhauru) The space argument is used by the old Gibbs sampler. To be removed. +function replace_raw_storage(vnv::VarNamedVector, ::Val{space}, vals) where {space} + if length(space) > 0 + msg = "Selecting values in a VarNamedVector with a space is not supported."
+ throw(ArgumentError(msg)) + end + return replace_raw_storage(vnv, vals) +end + +""" + unflatten(vnv::VarNamedVector, vals::AbstractVector) + +Return a new instance of `vnv` with the values of `vals` assigned to the variables. + +This assumes that `vals` have been transformed by the same transformations that the +values in `vnv` have been transformed by. However, unlike [`replace_raw_storage`](@ref), +`unflatten` does account for inactive entries in `vnv`, so that the user does not have to +care about them. + +This is in a sense the reverse operation of `vnv[:]`. + +Unflatten recontiguifies the internal storage, getting rid of any inactive entries. + +# Examples + +```jldoctest varnamedvector-unflatten +julia> using DynamicPPL: VarNamedVector, unflatten + +julia> vnv = VarNamedVector(@varname(x) => [1.0, 2.0], @varname(y) => [3.0]); + +julia> unflatten(vnv, vnv[:]) == vnv +true +``` +""" +function unflatten(vnv::VarNamedVector, vals::AbstractVector) + new_ranges = deepcopy(vnv.ranges) + recontiguify_ranges!(new_ranges) + return VarNamedVector( + vnv.varname_to_index, vnv.varnames, new_ranges, vals, vnv.transforms + ) +end + +# TODO(mhauru) To be removed once the old Gibbs sampler is removed. +function unflatten(vnv::VarNamedVector, spl::AbstractSampler, vals::AbstractVector) + if length(getspace(spl)) > 0 + msg = "Selecting values in a VarNamedVector with a space is not supported." + throw(ArgumentError(msg)) + end + return unflatten(vnv, vals) +end + +function Base.merge(left_vnv::VarNamedVector, right_vnv::VarNamedVector) + # Return early if possible. + isempty(left_vnv) && return deepcopy(right_vnv) + isempty(right_vnv) && return deepcopy(left_vnv) + + # Determine varnames. + vns_left = left_vnv.varnames + vns_right = right_vnv.varnames + vns_both = union(vns_left, vns_right) + + # Determine `eltype` of `vals`. + T_left = eltype(left_vnv.vals) + T_right = eltype(right_vnv.vals) + T = promote_type(T_left, T_right) + + # Determine `eltype` of `varnames`.
+ V_left = eltype(left_vnv.varnames) + V_right = eltype(right_vnv.varnames) + V = promote_type(V_left, V_right) + if !(V <: VarName) + V = VarName + end + + # Determine `eltype` of `transforms`. + F_left = eltype(left_vnv.transforms) + F_right = eltype(right_vnv.transforms) + F = promote_type(F_left, F_right) + + # Allocate. + varname_to_index = OrderedDict{V,Int}() + ranges = UnitRange{Int}[] + vals = T[] + transforms = F[] + is_unconstrained = BitVector(undef, length(vns_both)) + + # Range offset. + offset = 0 + + for (idx, vn) in enumerate(vns_both) + varname_to_index[vn] = idx + # Extract the necessary information from `left` or `right`. + if vn in vns_left && !(vn in vns_right) + # `vn` is only in `left`. + val = getindex_internal(left_vnv, vn) + f = gettransform(left_vnv, vn) + is_unconstrained[idx] = istrans(left_vnv, vn) + else + # `vn` is either in both or just `right`. + # Note that in a `merge` the right value has precedence. + val = getindex_internal(right_vnv, vn) + f = gettransform(right_vnv, vn) + is_unconstrained[idx] = istrans(right_vnv, vn) + end + n = length(val) + r = (offset + 1):(offset + n) + # Update. + append!(vals, val) + push!(ranges, r) + push!(transforms, f) + # Increment `offset`. + offset += n + end + + return VarNamedVector( + varname_to_index, vns_both, ranges, vals, transforms, is_unconstrained + ) +end + +""" + subset(vnv::VarNamedVector, vns::AbstractVector{<:VarName}) + +Return a new `VarNamedVector` containing the values from `vnv` for variables in `vns`. + +Which variables to include is determined by the `VarName`'s `subsumes` relation, meaning +that e.g. `subset(vnv, [@varname(x)])` will include variables like `@varname(x.a[1])`. 
+ +# Examples + +```jldoctest varnamedvector-subset +julia> using DynamicPPL: VarNamedVector, @varname, subset + +julia> vnv = VarNamedVector(@varname(x) => [1.0, 2.0], @varname(y) => [3.0]); + +julia> subset(vnv, [@varname(x)]) == VarNamedVector(@varname(x) => [1.0, 2.0]) +true + +julia> subset(vnv, [@varname(x[2])]) == VarNamedVector(@varname(x[2]) => [2.0]) +true +""" +function subset(vnv::VarNamedVector, vns_given::AbstractVector{VN}) where {VN<:VarName} + # NOTE: This does not specialize types when possible. + vns = mapreduce(vcat, vns_given; init=VN[]) do vn + filter(Base.Fix1(subsumes, vn), vnv.varnames) + end + vnv_new = similar(vnv) + # Return early if possible. + isempty(vnv) && return vnv_new + + for vn in vns + insert_internal!(vnv_new, getindex_internal(vnv, vn), vn, gettransform(vnv, vn)) + settrans!(vnv_new, istrans(vnv, vn), vn) + end + + return vnv_new +end + +""" + similar(vnv::VarNamedVector) + +Return a new `VarNamedVector` with the same structure as `vnv`, but with empty values. + +In this respect `vnv` behaves more like a dictionary than an array: `similar(vnv)` will +be entirely empty, rather than have `undef` values in it. + +# Examples + +```julia-doctest-varnamedvector-similar +julia> using DynamicPPL: VarNamedVector, @varname, similar + +julia> vnv = VarNamedVector(@varname(x) => [1.0, 2.0], @varname(x[3]) => [3.0]); + +julia> similar(vnv) == VarNamedVector{VarName{:x}, Float64}() +true +""" +function Base.similar(vnv::VarNamedVector) + # NOTE: Whether or not we should empty the underlying containers or not + # is somewhat ambiguous. For example, `similar(vnv.varname_to_index)` will + # result in an empty `AbstractDict`, while the vectors, e.g. `vnv.ranges`, + # will result in non-empty vectors but with entries as `undef`. But it's + # much easier to write the rest of the code assuming that `undef` is not + # present, and so for now we empty the underlying containers, thus differing + # from the behavior of `similar` for `AbstractArray`s. 
+ return VarNamedVector( + empty(vnv.varname_to_index), + similar(vnv.varnames, 0), + similar(vnv.ranges, 0), + similar(vnv.vals, 0), + similar(vnv.transforms, 0), + BitVector(), + empty(vnv.num_inactive), + ) +end + +""" + is_contiguous(vnv::VarNamedVector) + +Returns `true` if the underlying data of `vnv` is stored in a contiguous array. + +This is equivalent to negating [`has_inactive(vnv)`](@ref). +""" +is_contiguous(vnv::VarNamedVector) = !has_inactive(vnv) + +""" + nextrange(vnv::VarNamedVector, x) + +Return the range of `length(x)` from the end of current data in `vnv`. +""" +function nextrange(vnv::VarNamedVector, x) + offset = length(vnv.vals) + return (offset + 1):(offset + length(x)) +end + +# TODO(mhauru) Might add another specialisation to _compose_no_identity, where if +# ReshapeTransforms are composed with each other or with an UnwrapSingletonTransform, only +# the latter one would be kept. +""" + _compose_no_identity(f, g) + +Like `f ∘ g`, but if `f` or `g` is `identity` it is omitted. + +This helps avoid trivial cases of `ComposedFunction` that would cause unnecessary type +conflicts. +""" +_compose_no_identity(f, g) = f ∘ g +_compose_no_identity(::typeof(identity), g) = g +_compose_no_identity(f, ::typeof(identity)) = f +_compose_no_identity(::typeof(identity), ::typeof(identity)) = identity + +""" + shift_right!(x::AbstractVector{<:Real}, start::Int, n::Int) + +Shifts the elements of `x` starting from index `start` by `n` to the right. +""" +function shift_right!(x::AbstractVector{<:Real}, start::Int, n::Int) + x[(start + n):end] = x[start:(end - n)] + return x +end + +""" + shift_subsequent_ranges_by!(vnv::VarNamedVector, idx::Int, n) + +Shifts the ranges of all variables in `vnv` after index `idx` by `n`. +""" +function shift_subsequent_ranges_by!(vnv::VarNamedVector, idx::Int, n) + for i in (idx + 1):length(vnv.ranges) + vnv.ranges[i] = vnv.ranges[i] .+ n + end + return nothing +end + +# set!!
is the function defined in utils.jl that tries to do fancy stuff with optics when +# setting the value of a generic container using a VarName. We can bypass all that because +# VarNamedVector handles VarNames natively. However, its semantics are slightly different +# from setindex!'s: It allows resetting variables that already have a value with values of +# a different type/size. +set!!(vnv::VarNamedVector, vn::VarName, val) = reset!!(vnv, val, vn) + +function setval!(vnv::VarNamedVector, val, vn::VarName) + return setindex_internal!(vnv, tovec(val), vn) +end + +function recontiguify_ranges!(ranges::AbstractVector{<:AbstractRange}) + offset = 0 + for i in 1:length(ranges) + r_old = ranges[i] + ranges[i] = (offset + 1):(offset + length(r_old)) + offset += length(r_old) + end + + return ranges +end + +""" + contiguify!(vnv::VarNamedVector) + +Re-contiguify the underlying vector and shrink if possible. + +# Examples + +```jldoctest varnamedvector-contiguify +julia> using DynamicPPL: VarNamedVector, @varname, contiguify!, update!, has_inactive + +julia> vnv = VarNamedVector(@varname(x) => [1.0, 2.0, 3.0], @varname(y) => [3.0]); + +julia> update!(vnv, [23.0, 24.0], @varname(x)); + +julia> has_inactive(vnv) +true + +julia> length(vnv.vals) +4 + +julia> contiguify!(vnv); + +julia> has_inactive(vnv) +false + +julia> length(vnv.vals) +3 + +julia> vnv[@varname(x)] # All the values are still there. +2-element Vector{Float64}: + 23.0 + 24.0 +``` +""" +function contiguify!(vnv::VarNamedVector) + # Extract the re-contiguified values. + # NOTE: We need to do this before we update the ranges. + old_vals = copy(vnv.vals) + old_ranges = copy(vnv.ranges) + # And then we re-contiguify the ranges. + recontiguify_ranges!(vnv.ranges) + # Clear the inactive ranges. + empty!(vnv.num_inactive) + # Now we update the values. + for (old_range, new_range) in zip(old_ranges, vnv.ranges) + vnv.vals[new_range] = old_vals[old_range] + end + # And (potentially) shrink the underlying vector.
+ resize!(vnv.vals, vnv.ranges[end][end]) + # The rest should be left as is. + return vnv +end + +""" + group_by_symbol(vnv::VarNamedVector) + +Return a dictionary mapping symbols to `VarNamedVector`s with varnames containing that +symbol. + +# Examples + +```jldoctest varnamedvector-group-by-symbol +julia> using DynamicPPL: VarNamedVector, @varname, group_by_symbol + +julia> vnv = VarNamedVector(@varname(x) => [1.0], @varname(y) => [2.0], @varname(x[1]) => [3.0]); + +julia> d = group_by_symbol(vnv); + +julia> collect(keys(d)) +[Symbol("x"), Symbol("y")] + +julia> d[@varname(x)] == VarNamedVector(@varname(x) => [1.0], @varname(x[1]) => [3.0]) +true + +julia> d[@varname(y)] == VarNamedVector(@varname(y) => [2.0]) +true +""" +function group_by_symbol(vnv::VarNamedVector) + symbols = unique(map(getsym, vnv.varnames)) + nt_vals = map(s -> tighten_types(subset(vnv, [VarName{s}()])), symbols) + return OrderedDict(zip(symbols, nt_vals)) +end + +""" + shift_index_left!(vnv::VarNamedVector, idx::Int) + +Shift the index `idx` to the left by one and update the relevant fields. + +This only affects `vnv.varname_to_index` and `vnv.num_inactive` and is only valid as a +helper function for [`shift_subsequent_indices_left!`](@ref). + +!!! warning + This does not check if index we're shifting to is already occupied. +""" +function shift_index_left!(vnv::VarNamedVector, idx::Int) + # Shift the index in the lookup table. + vn = vnv.varnames[idx] + vnv.varname_to_index[vn] = idx - 1 + # Shift the index in the inactive ranges. + if haskey(vnv.num_inactive, idx) + # Done in increasing order => don't need to worry about + # potentially shifting the same index twice. + vnv.num_inactive[idx - 1] = pop!(vnv.num_inactive, idx) + end +end + +""" + shift_subsequent_indices_left!(vnv::VarNamedVector, idx::Int) + +Shift the indices for all variables after `idx` to the left by one and update the relevant + fields. 
+ +This only affects `vnv.varname_to_index` and `vnv.num_inactive` and is only valid as a +helper function for [`delete!`](@ref). +""" +function shift_subsequent_indices_left!(vnv::VarNamedVector, idx::Int) + # Shift the indices for all variables after `idx`. + for idx_to_shift in (idx + 1):length(vnv.varnames) + shift_index_left!(vnv, idx_to_shift) + end +end + +function Base.delete!(vnv::VarNamedVector, vn::VarName) + # Error if we don't have the variable. + !haskey(vnv, vn) && throw(ArgumentError("variable name $vn does not exist")) + + # Get the index of the variable. + idx = getidx(vnv, vn) + + # Delete the values. + r_start = first(getrange(vnv, idx)) + n_allocated = num_allocated(vnv, idx) + # NOTE: `deleteat!` also results in a `resize!` so we don't need to do that. + deleteat!(vnv.vals, r_start:(r_start + n_allocated - 1)) + + # Delete `vn` from the lookup table. + delete!(vnv.varname_to_index, vn) + + # Delete any inactive ranges corresponding to `vn`. + haskey(vnv.num_inactive, idx) && delete!(vnv.num_inactive, idx) + + # Re-adjust the indices for varnames occurring after `vn` so + # that they point to the correct indices after the deletions below. + shift_subsequent_indices_left!(vnv, idx) + + # Re-adjust the ranges for varnames occurring after `vn`. + shift_subsequent_ranges_by!(vnv, idx, -n_allocated) + + # Delete references from vector fields, thus shifting the indices of + # varnames occurring after `vn` by one to the left, as we adjusted for above. + deleteat!(vnv.varnames, idx) + deleteat!(vnv.ranges, idx) + deleteat!(vnv.transforms, idx) + + return vnv +end + +""" + values_as(vnv::VarNamedVector[, T]) + +Return the values/realizations in `vnv` as type `T`, if implemented. + +If no type `T` is provided, return values as stored in `vnv`.
+ +# Examples + +```jldoctest +julia> using DynamicPPL: VarNamedVector + +julia> vnv = VarNamedVector(@varname(x) => 1, @varname(y) => [2.0]); + +julia> values_as(vnv) == [1.0, 2.0] +true + +julia> values_as(vnv, Vector{Float32}) == Vector{Float32}([1.0, 2.0]) +true + +julia> values_as(vnv, OrderedDict) == OrderedDict(@varname(x) => 1.0, @varname(y) => [2.0]) +true + +julia> values_as(vnv, NamedTuple) == (x = 1.0, y = [2.0]) +true +``` +""" +values_as(vnv::VarNamedVector) = values_as(vnv, Vector) +values_as(vnv::VarNamedVector, ::Type{Vector}) = getindex_internal(vnv, :) +function values_as(vnv::VarNamedVector, ::Type{Vector{T}}) where {T} + return convert(Vector{T}, values_as(vnv, Vector)) +end +function values_as(vnv::VarNamedVector, ::Type{NamedTuple}) + return NamedTuple(zip(map(Symbol, keys(vnv)), values(vnv))) +end +function values_as(vnv::VarNamedVector, ::Type{D}) where {D<:AbstractDict} + return ConstructionBase.constructorof(D)(pairs(vnv)) +end + +# See the docstring of `getvalue` for the semantics of `hasvalue` and `getvalue`, and how +# they differ from `haskey` and `getindex`. They can be found in src/utils.jl. + +# TODO(mhauru) This is tricky to implement in the general case, and the below implementation +# only covers some simple cases. It's probably sufficient in most situations though. +function hasvalue(vnv::VarNamedVector, vn::VarName) + haskey(vnv, vn) && return true + any(subsumes(vn, k) for k in keys(vnv)) && return true + # Handle the easy case where the right symbol isn't even present. + !any(k -> getsym(k) == getsym(vn), keys(vnv)) && return false + + optic = getoptic(vn) + if optic isa Accessors.IndexLens || optic isa Accessors.ComposedOptic + # If vn is of the form @varname(somesymbol[someindex]), we check whether we store + # @varname(somesymbol) and can index into it with someindex. If we rather have a + # composed optic with the last part being an index lens, we do a similar check but + # stripping out the last index lens part. 
If these pass, the answer is definitely + # "yes". If not, we still don't know for sure. + # TODO(mhauru) What about cases where vnv stores both @varname(x) and + # @varname(x[1]) or @varname(x.a)? Those should probably be banned, but currently + # aren't. + head, tail = if optic isa Accessors.ComposedOptic + decomp_optic = Accessors.decompose(optic) + first(decomp_optic), Accessors.compose(decomp_optic[2:end]...) + else + optic, identity + end + parent_varname = VarName{getsym(vn)}(tail) + if haskey(vnv, parent_varname) + valvec = getindex(vnv, parent_varname) + return canview(head, valvec) + end + end + throw(ErrorException("hasvalue has not been fully implemented for this VarName: $(vn)")) +end + +# TODO(mhauru) Like hasvalue, this is only partially implemented. +function getvalue(vnv::VarNamedVector, vn::VarName) + !hasvalue(vnv, vn) && throw(KeyError(vn)) + haskey(vnv, vn) && return getindex(vnv, vn) + + subsumed_keys = filter(k -> subsumes(vn, k), keys(vnv)) + if length(subsumed_keys) > 0 + # TODO(mhauru) What happens if getindex returns e.g. matrices, and we vcat them? + return mapreduce(k -> getindex(vnv, k), vcat, subsumed_keys) + end + + optic = getoptic(vn) + # See hasvalue for some comments on the logic of this if block. + if optic isa Accessors.IndexLens || optic isa Accessors.ComposedOptic + head, tail = if optic isa Accessors.ComposedOptic + decomp_optic = Accessors.decompose(optic) + first(decomp_optic), Accessors.compose(decomp_optic[2:end]...)
+ else + optic, identity + end + parent_varname = VarName{getsym(vn)}(tail) + valvec = getindex(vnv, parent_varname) + return head(valvec) + end + throw(ErrorException("getvalue has not been fully implemented for this VarName: $(vn)")) +end + +Base.get(vnv::VarNamedVector, vn::VarName) = getvalue(vnv, vn) diff --git a/test/Project.toml b/test/Project.toml index f0e978af8..36ee4baa8 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,6 +4,7 @@ AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -31,6 +32,7 @@ AbstractMCMC = "5" AbstractPPL = "0.8.4, 0.9" Accessors = "0.1" Bijectors = "0.13.9" +Combinatorics = "1" Compat = "4.3.0" Distributions = "0.25" DistributionsAD = "0.6.3" diff --git a/test/compiler.jl b/test/compiler.jl index f1f06eabe..f2d7e5852 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -309,11 +309,11 @@ module Issue537 end vi2 = VarInfo(f2()) vi3 = VarInfo(f3()) @test haskey(vi1.metadata, :y) - @test vi1.metadata.y.vns[1] == @varname(y) + @test first(Base.keys(vi1.metadata.y)) == @varname(y) @test haskey(vi2.metadata, :y) - @test vi2.metadata.y.vns[1] == @varname(y[2][:, 1]) + @test first(Base.keys(vi2.metadata.y)) == @varname(y[2][:, 1]) @test haskey(vi3.metadata, :y) - @test vi3.metadata.y.vns[1] == @varname(y[1]) + @test first(Base.keys(vi3.metadata.y)) == @varname(y[1]) # Conditioning f1_c = f1() | (y=1,) diff --git a/test/model.jl b/test/model.jl index 60a8d2461..d163f55f0 100644 --- a/test/model.jl +++ b/test/model.jl @@ -122,7 +122,6 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true @test logjoints[i] ≈ 
DynamicPPL.TestUtils.logjoint_true(model, samples[:s], samples[:m]) end - println("\n model $(model) passed !!! \n") end end @@ -200,10 +199,10 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true @test !any(map(x -> x isa DynamicPPL.AbstractVarInfo, call_retval)) end - @testset "Dynamic constraints" begin + @testset "Dynamic constraints, Metadata" begin model = DynamicPPL.TestUtils.demo_dynamic_constraint() - vi = VarInfo(model) spl = SampleFromPrior() + vi = VarInfo(model, spl, DefaultContext(), DynamicPPL.Metadata()) link!!(vi, spl, model) for i in 1:10 @@ -216,6 +215,14 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end end + @testset "Dynamic constraints, VectorVarInfo" begin + model = DynamicPPL.TestUtils.demo_dynamic_constraint() + for i in 1:10 + vi = VarInfo(model) + @test vi[@varname(x)] >= vi[@varname(m)] + end + end + @testset "rand" begin model = gdemo_default @@ -324,7 +331,6 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true chain = MCMCChains.Chains( permutedims(stack(vals)), syms; info=(varname_to_symbol=vns_to_syms,) ) - display(chain) # Test! results = generated_quantities(model, chain) @@ -345,7 +351,6 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true vcat(syms, [:y]); info=(varname_to_symbol=vns_to_syms_with_extra,), ) - display(chain_with_extra) # Test! 
results = generated_quantities(model, chain_with_extra) for (x_true, result) in zip(xs, results) @@ -358,6 +363,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true models_to_test = [ DynamicPPL.TestUtils.DEMO_MODELS..., DynamicPPL.TestUtils.demo_lkjchol(2) ] + context = DefaultContext() @testset "$(model.f)" for model in models_to_test vns = DynamicPPL.TestUtils.varnames(model) example_values = DynamicPPL.TestUtils.rand_prior_true(model) @@ -366,18 +372,16 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns), ) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - @test ( - @inferred(DynamicPPL.evaluate!!(model, varinfo, DefaultContext())); + @test begin + @inferred(DynamicPPL.evaluate!!(model, varinfo, context)) true - ) + end varinfo_linked = DynamicPPL.link(varinfo, model) - @test ( - @inferred( - DynamicPPL.evaluate!!(model, varinfo_linked, DefaultContext()) - ); + @test begin + @inferred(DynamicPPL.evaluate!!(model, varinfo_linked, context)) true - ) + end end end end diff --git a/test/runtests.jl b/test/runtests.jl index b9a1d92bd..099c96f78 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -25,6 +25,8 @@ using Test using Distributions using LinearAlgebra # Diagonal +using Combinatorics: combinations + using DynamicPPL: getargs_dottilde, getargs_tilde, Selector const DIRECTORY_DynamicPPL = dirname(dirname(pathof(DynamicPPL))) @@ -40,6 +42,7 @@ include("test_util.jl") @testset "interface" begin include("utils.jl") include("compiler.jl") + include("varnamedvector.jl") include("varinfo.jl") include("simple_varinfo.jl") include("model.jl") diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 5ce112941..4343563eb 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -56,13 +56,43 @@ @test !haskey(svi, @varname(m.a[2])) @test !haskey(svi, @varname(m.a.b)) end + + @testset "VarNamedVector" begin + svi = 
SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m) => 1.0)) + @test getlogp(svi) == 0.0 + @test haskey(svi, @varname(m)) + @test !haskey(svi, @varname(m[1])) + + svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m) => [1.0])) + @test getlogp(svi) == 0.0 + @test haskey(svi, @varname(m)) + @test haskey(svi, @varname(m[1])) + @test !haskey(svi, @varname(m[2])) + @test svi[@varname(m)][1] == svi[@varname(m[1])] + + svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m.a) => [1.0])) + @test haskey(svi, @varname(m)) + @test haskey(svi, @varname(m.a)) + @test haskey(svi, @varname(m.a[1])) + @test !haskey(svi, @varname(m.a[2])) + @test !haskey(svi, @varname(m.a.b)) + # The implementation of haskey and getvalue for VarNamedVector is incomplete; the + # next test is here to remind us of that. + svi = SimpleVarInfo( + push!!(DynamicPPL.VarNamedVector(), @varname(m.a.b) => [1.0]) + ) + @test_broken !haskey(svi, @varname(m.a.b.c.d)) + end end @testset "link!! & invlink!! on $(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS values_constrained = DynamicPPL.TestUtils.rand_prior_true(model) @testset "$(typeof(vi))" for vi in ( - SimpleVarInfo(Dict()), SimpleVarInfo(values_constrained), VarInfo(model) + SimpleVarInfo(Dict()), + SimpleVarInfo(values_constrained), + SimpleVarInfo(DynamicPPL.VarNamedVector()), + VarInfo(model), ) for vn in DynamicPPL.TestUtils.varnames(model) vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn) @@ -115,12 +145,19 @@ # to see whether this is the case. 
svi_nt = SimpleVarInfo(DynamicPPL.TestUtils.rand_prior_true(model)) svi_dict = SimpleVarInfo(VarInfo(model), Dict) + vnv = DynamicPPL.VarNamedVector() + for (k, v) in pairs(DynamicPPL.TestUtils.rand_prior_true(model)) + vnv = push!!(vnv, VarName{k}() => v) + end + svi_vnv = SimpleVarInfo(vnv) @testset "$(nameof(typeof(DynamicPPL.values_as(svi))))" for svi in ( svi_nt, svi_dict, - DynamicPPL.settrans!!(svi_nt, true), - DynamicPPL.settrans!!(svi_dict, true), + svi_vnv, + DynamicPPL.settrans!!(deepcopy(svi_nt), true), + DynamicPPL.settrans!!(deepcopy(svi_dict), true), + DynamicPPL.settrans!!(deepcopy(svi_vnv), true), ) # Random seed is set in each `@testset`, so we need to sample # a new realization for `m` here. @@ -195,30 +232,34 @@ model = DynamicPPL.TestUtils.demo_dynamic_constraint() # Initialize. - svi = DynamicPPL.settrans!!(SimpleVarInfo(), true) - svi = last(DynamicPPL.evaluate!!(model, svi, SamplingContext())) - - # Sample with large variations in unconstrained space. - for i in 1:10 - for vn in keys(svi) - svi = DynamicPPL.setindex!!(svi, 10 * randn(), vn) - end - retval, svi = DynamicPPL.evaluate!!(model, svi, DefaultContext()) - @test retval.m == svi[@varname(m)] # `m` is unconstrained - @test retval.x ≠ svi[@varname(x)] # `x` is constrained depending on `m` + svi_nt = DynamicPPL.settrans!!(SimpleVarInfo(), true) + svi_nt = last(DynamicPPL.evaluate!!(model, svi_nt, SamplingContext())) + svi_vnv = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) + svi_vnv = last(DynamicPPL.evaluate!!(model, svi_vnv, SamplingContext())) + + for svi in (svi_nt, svi_vnv) + # Sample with large variations in unconstrained space. 
+ for i in 1:10 + for vn in keys(svi) + svi = DynamicPPL.setindex!!(svi, 10 * randn(), vn) + end + retval, svi = DynamicPPL.evaluate!!(model, svi, DefaultContext()) + @test retval.m == svi[@varname(m)] # `m` is unconstrained + @test retval.x ≠ svi[@varname(x)] # `x` is constrained depending on `m` + + retval_unconstrained, lp_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( + model, retval.m, retval.x + ) - retval_unconstrained, lp_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( - model, retval.m, retval.x - ) + # Realizations from model should all be equal to the unconstrained realization. + for vn in DynamicPPL.TestUtils.varnames(model) + @test get(retval_unconstrained, vn) ≈ svi[vn] rtol = 1e-6 + end - # Realizations from model should all be equal to the unconstrained realization. - for vn in DynamicPPL.TestUtils.varnames(model) - @test get(retval_unconstrained, vn) ≈ svi[vn] rtol = 1e-6 + # `getlogp` should be equal to the logjoint with log-absdet-jac correction. + lp = getlogp(svi) + @test lp ≈ lp_true end - - # `getlogp` should be equal to the logjoint with log-absdet-jac correction. - lp = getlogp(svi) - @test lp ≈ lp_true end end diff --git a/test/test_util.jl b/test/test_util.jl index 64832f51e..f1325b729 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -84,10 +84,17 @@ Return string representing a short description of `vi`. 
""" short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) = "threadsafe($(short_varinfo_name(vi.varinfo)))" -short_varinfo_name(::TypedVarInfo) = "TypedVarInfo" +function short_varinfo_name(vi::TypedVarInfo) + DynamicPPL.has_varnamedvector(vi) && return "TypedVarInfo with VarNamedVector" + return "TypedVarInfo" +end short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo" +short_varinfo_name(::DynamicPPL.VectorVarInfo) = "VectorVarInfo" short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}" short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}" +function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector}) + return "SimpleVarInfo{<:VarNamedVector}" +end # convenient functions for testing model.jl # function to modify the representation of values based on their length diff --git a/test/varinfo.jl b/test/varinfo.jl index 6a3d8d2bc..65f849dda 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -19,7 +19,7 @@ struct MySAlg end DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @testset "varinfo.jl" begin - @testset "TypedVarInfo" begin + @testset "TypedVarInfo with Metadata" begin @model gdemo(x, y) = begin s ~ InverseGamma(2, 3) m ~ truncated(Normal(0.0, sqrt(s)), 0.0, 2.0) @@ -28,7 +28,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) end model = gdemo(1.0, 2.0) - vi = VarInfo() + vi = VarInfo(DynamicPPL.Metadata()) model(vi, SampleFromUniform()) tvi = TypedVarInfo(vi) @@ -51,6 +51,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) end end end + @testset "Base" begin # Test Base functions: # string, Symbol, ==, hash, in, keys, haskey, isempty, push!!, empty!!, @@ -110,6 +111,13 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @test vi[vn] == 3 * r @test vi[SampleFromPrior()][1] == 3 * r + # TODO(mhauru) Implement these functions for other VarInfo types too. 
+ if vi isa DynamicPPL.VectorVarInfo + delete!(vi, vn) + @test isempty(vi) + vi = push!!(vi, vn, r, dist, gid) + end + vi = empty!!(vi) @test isempty(vi) return push!!(vi, vn, r, dist, gid) @@ -120,6 +128,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) test_base!!(TypedVarInfo(vi)) test_base!!(SimpleVarInfo()) test_base!!(SimpleVarInfo(Dict())) + test_base!!(SimpleVarInfo(DynamicPPL.VarNamedVector())) end @testset "flags" begin # Test flag setting: @@ -141,12 +150,12 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) unset_flag!(vi, vn_x, "del") @test !is_flagged(vi, vn_x, "del") end - vi = VarInfo() + vi = VarInfo(DynamicPPL.Metadata()) test_varinfo!(vi) test_varinfo!(empty!!(TypedVarInfo(vi))) end @testset "setgid!" begin - vi = VarInfo() + vi = VarInfo(DynamicPPL.Metadata()) meta = vi.metadata vn = @varname x dist = Normal(0, 1) @@ -196,18 +205,36 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) m_vns = model == model_uv ? [@varname(m[i]) for i in 1:5] : @varname(m) s_vns = @varname(s) - vi_typed = VarInfo(model) - vi_untyped = VarInfo() + vi_typed = VarInfo( + model, SampleFromPrior(), DefaultContext(), DynamicPPL.Metadata() + ) + vi_untyped = VarInfo(DynamicPPL.Metadata()) + vi_vnv = VarInfo(DynamicPPL.VarNamedVector()) + vi_vnv_typed = VarInfo( + model, SampleFromPrior(), DefaultContext(), DynamicPPL.VarNamedVector() + ) model(vi_untyped, SampleFromPrior()) + model(vi_vnv, SampleFromPrior()) - for vi in [vi_untyped, vi_typed] + model_name = model == model_uv ? "univariate" : "multivariate" + @testset "$(model_name), $(short_varinfo_name(vi))" for vi in [ + vi_untyped, vi_typed, vi_vnv, vi_vnv_typed + ] + Random.seed!(23) vicopy = deepcopy(vi) ### `setval` ### - DynamicPPL.setval!(vicopy, (m=zeros(5),)) + # TODO(mhauru) The interface here seems inconsistent between Metadata and + # VarNamedVector. I'm lazy to fix it though, because I think we need to + # rework it soon anyway. 
+ if vi in [vi_vnv, vi_vnv_typed] + DynamicPPL.setval!(vicopy, zeros(5), m_vns) + else + DynamicPPL.setval!(vicopy, (m=zeros(5),)) + end # Setting `m` fails for univariate due to limitations of `setval!` # and `setval_and_resample!`. See docstring of `setval!` for more info. - if model == model_uv + if model == model_uv && vi in [vi_untyped, vi_typed] @test_broken vicopy[m_vns] == zeros(5) else @test vicopy[m_vns] == zeros(5) @@ -240,6 +267,13 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) continue end + if vi in [vi_vnv, vi_vnv_typed] + # `setval_and_resample!` works differently for `VarNamedVector`: All + # values will be resampled when model(vicopy) is called. Hence the below + # tests are not applicable. + continue + end + vicopy = deepcopy(vi) DynamicPPL.setval_and_resample!(vicopy, (m=zeros(5),)) model(vicopy) @@ -338,6 +372,14 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + + ## `SimpleVarInfo{<:VarNamedVector}` + vi = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) + # Sample in unconstrained space. + vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) + x = f(DynamicPPL.getindex_internal(vi, vn)) + @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) end @testset "values_as" begin @@ -409,6 +451,12 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) continue end + if DynamicPPL.has_varnamedvector(varinfo) && mutating + # NOTE: Can't handle mutating `link!` and `invlink!` `VarNamedVector`. + @test_broken false + continue + end + # Evaluate the model once to update the logp of the varinfo. 
varinfo = last(DynamicPPL.evaluate!!(model, varinfo, DefaultContext())) @@ -636,6 +684,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) varinfo_left = VarInfo(model_left) varinfo_right = VarInfo(model_right) + varinfo_right = DynamicPPL.settrans!!(varinfo_right, true, @varname(x)) varinfo_merged = merge(varinfo_left, varinfo_right) vns = [@varname(x), @varname(y), @varname(z)] @@ -643,13 +692,18 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) # Right has precedence. @test varinfo_merged[@varname(x)] == varinfo_right[@varname(x)] - @test DynamicPPL.getdist(varinfo_merged, @varname(x)) isa Normal + @test DynamicPPL.istrans(varinfo_merged, @varname(x)) end end @testset "VarInfo with selectors" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - varinfo = VarInfo(model) + varinfo = VarInfo( + model, + DynamicPPL.SampleFromPrior(), + DynamicPPL.DefaultContext(), + DynamicPPL.Metadata(), + ) selector = DynamicPPL.Selector() spl = Sampler(MySAlg(), model, selector) diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl new file mode 100644 index 000000000..bd3f5553f --- /dev/null +++ b/test/varnamedvector.jl @@ -0,0 +1,626 @@ +replace_sym(vn::VarName, sym_new::Symbol) = VarName{sym_new}(vn.lens) + +increase_size_for_test(x::Real) = [x] +increase_size_for_test(x::AbstractArray) = repeat(x, 2) + +decrease_size_for_test(x::Real) = x +decrease_size_for_test(x::AbstractVector) = first(x) +decrease_size_for_test(x::AbstractArray) = first(eachslice(x; dims=1)) + +function need_varnames_relaxation(vnv::DynamicPPL.VarNamedVector, vn::VarName, val) + if isconcretetype(eltype(vnv.varnames)) + # If the container is concrete, we need to make sure that the varname types match. + # E.g. if `vnv.varnames` has `eltype` `VarName{:x, IndexLens{Tuple{Int64}}}` then + # we need `vn` to also be of this type. + # => If the varname types don't match, we need to relax the container type. 
+ return any(keys(vnv)) do vn_present + typeof(vn_present) !== typeof(vn) + end + end + + return false +end +function need_varnames_relaxation(vnv::DynamicPPL.VarNamedVector, vns, vals) + return any(need_varnames_relaxation(vnv, vn, val) for (vn, val) in zip(vns, vals)) +end + +function need_values_relaxation(vnv::DynamicPPL.VarNamedVector, vn::VarName, val) + if isconcretetype(eltype(vnv.vals)) + return promote_type(eltype(vnv.vals), eltype(val)) != eltype(vnv.vals) + end + + return false +end +function need_values_relaxation(vnv::DynamicPPL.VarNamedVector, vns, vals) + return any(need_values_relaxation(vnv, vn, val) for (vn, val) in zip(vns, vals)) +end + +function need_transforms_relaxation(vnv::DynamicPPL.VarNamedVector, vn::VarName, val) + return if isconcretetype(eltype(vnv.transforms)) + # If the container is concrete, we need to make sure that the sizes match. + # => If the sizes don't match, we need to relax the container type. + any(keys(vnv)) do vn_present + size(vnv[vn_present]) != size(val) + end + elseif eltype(vnv.transforms) !== Any + # If it's not concrete AND it's not `Any`, then we should just make it `Any`. + true + else + # Otherwise, it's `Any`, so we don't need to relax the container type. + false + end +end +function need_transforms_relaxation(vnv::DynamicPPL.VarNamedVector, vns, vals) + return any(need_transforms_relaxation(vnv, vn, val) for (vn, val) in zip(vns, vals)) +end + +""" + relax_container_types(vnv::VarNamedVector, vn::VarName, val) + relax_container_types(vnv::VarNamedVector, vns, val) + +Relax the container types of `vnv` if necessary to accommodate `vn` and `val`. + +This attempts to avoid unnecessary container type relaxations by checking whether +the container types of `vnv` are already compatible with `vn` and `val`. + +# Notes +For example, if `vn` is not compatible with the current keys in `vnv`, then +the underlying types will be changed to `VarName` to accommodate `vn`. 
+ +Similarly: +- If `val` is not compatible with the current values in `vnv`, then + the underlying value type will be changed to `Real`. +- If `val` requires a transformation that is not compatible with the current + transformations type in `vnv`, then the underlying transformation type will + be changed to `Any`. +""" +function relax_container_types(vnv::DynamicPPL.VarNamedVector, vn::VarName, val) + return relax_container_types(vnv, [vn], [val]) +end +function relax_container_types(vnv::DynamicPPL.VarNamedVector, vns, vals) + if need_varnames_relaxation(vnv, vns, vals) + varname_to_index_new = convert(OrderedDict{VarName,Int}, vnv.varname_to_index) + varnames_new = convert(Vector{VarName}, vnv.varnames) + else + varname_to_index_new = vnv.varname_to_index + varnames_new = vnv.varnames + end + + transforms_new = if need_transforms_relaxation(vnv, vns, vals) + convert(Vector{Any}, vnv.transforms) + else + vnv.transforms + end + + vals_new = if need_values_relaxation(vnv, vns, vals) + convert(Vector{Real}, vnv.vals) + else + vnv.vals + end + + return DynamicPPL.VarNamedVector( + varname_to_index_new, + varnames_new, + vnv.ranges, + vals_new, + transforms_new, + vnv.is_unconstrained, + vnv.num_inactive, + ) +end + +@testset "VarNamedVector" begin + # Test element-related operations: + # - `getindex` + # - `setindex!` + # - `push!` + # - `update!` + # - `insert!` + # - `reset!` + # - `_internal!` versions of the above + # - !! 
versions of the above + # + # And these should all be tested for different types of values: + # - scalar + # - vector + # - matrix + + # Test operations on `VarNamedVector`: + # - `empty!` + # - `iterate` + # - `convert` to + # - `AbstractDict` + test_pairs = OrderedDict( + @varname(x[1]) => rand(), + @varname(x[2]) => rand(2), + @varname(x[3]) => rand(2, 3), + @varname(y[1]) => rand(), + @varname(y[2]) => rand(2), + @varname(y[3]) => rand(2, 3), + @varname(z[1]) => rand(1:10), + @varname(z[2]) => rand(1:10, 2), + @varname(z[3]) => rand(1:10, 2, 3), + ) + test_vns = collect(keys(test_pairs)) + test_vals = collect(values(test_pairs)) + + @testset "constructor: no args" begin + # Empty. + vnv = DynamicPPL.VarNamedVector() + @test isempty(vnv) + @test eltype(vnv) == Real + + # Empty with types. + vnv = DynamicPPL.VarNamedVector{VarName,Float64}() + @test isempty(vnv) + @test eltype(vnv) == Float64 + end + + test_varnames_iter = combinations(test_vns, 2) + @testset "$(vn_left) and $(vn_right)" for (vn_left, vn_right) in test_varnames_iter + val_left = test_pairs[vn_left] + val_right = test_pairs[vn_right] + vnv_base = DynamicPPL.VarNamedVector([vn_left, vn_right], [val_left, val_right]) + + # We'll need the transformations later. + # TODO: Should we test other transformations than just `ReshapeTransform`? + from_vec_left = DynamicPPL.from_vec_transform(val_left) + from_vec_right = DynamicPPL.from_vec_transform(val_right) + to_vec_left = inverse(from_vec_left) + to_vec_right = inverse(from_vec_right) + + # Compare to alternative constructors. + vnv_from_dict = DynamicPPL.VarNamedVector( + OrderedDict(vn_left => val_left, vn_right => val_right) + ) + @test vnv_base == vnv_from_dict + + # We want the types of fields such as `varnames` and `transforms` to specialize + # whenever possible + some functionality, e.g. `push!`, is only sensible + # if the underlying containers can support it. 
+ # Expected behavior + should_have_restricted_varname_type = typeof(vn_left) == typeof(vn_right) + should_have_restricted_transform_type = size(val_left) == size(val_right) + # Actual behavior + has_restricted_transform_type = isconcretetype(eltype(vnv_base.transforms)) + has_restricted_varname_type = isconcretetype(eltype(vnv_base.varnames)) + + @testset "type specialization" begin + @test !should_have_restricted_varname_type || has_restricted_varname_type + @test !should_have_restricted_transform_type || has_restricted_transform_type + end + + @test eltype(vnv_base) == promote_type(eltype(val_left), eltype(val_right)) + @test DynamicPPL.length_internal(vnv_base) == length(val_left) + length(val_right) + @test length(vnv_base) == 2 + + @test !isempty(vnv_base) + + @testset "empty!" begin + vnv = deepcopy(vnv_base) + empty!(vnv) + @test isempty(vnv) + end + + @testset "similar" begin + vnv = similar(vnv_base) + @test isempty(vnv) + @test typeof(vnv) == typeof(vnv_base) + end + + @testset "getindex" begin + # With `VarName` index. + @test vnv_base[vn_left] == val_left + @test vnv_base[vn_right] == val_right + end + + @testset "getindex_internal" begin + @test DynamicPPL.DynamicPPL.getindex_internal(vnv_base, vn_left) == + to_vec_left(val_left) + @test DynamicPPL.DynamicPPL.getindex_internal(vnv_base, vn_right) == + to_vec_right(val_right) + end + + @testset "getindex_internal with Ints" begin + for (i, val) in enumerate(to_vec_left(val_left)) + @test DynamicPPL.DynamicPPL.getindex_internal(vnv_base, i) == val + end + offset = length(to_vec_left(val_left)) + for (i, val) in enumerate(to_vec_right(val_right)) + @test DynamicPPL.DynamicPPL.getindex_internal(vnv_base, offset + i) == val + end + end + + @testset "update!" begin + vnv = deepcopy(vnv_base) + DynamicPPL.update!(vnv, val_left .+ 100, vn_left) + @test vnv[vn_left] == val_left .+ 100 + DynamicPPL.update!(vnv, val_right .+ 100, vn_right) + @test vnv[vn_right] == val_right .+ 100 + end + + @testset "update!!" 
begin + vnv = deepcopy(vnv_base) + vnv = DynamicPPL.update!!(vnv, val_left .+ 100, vn_left) + @test vnv[vn_left] == val_left .+ 100 + vnv = DynamicPPL.update!!(vnv, val_right .+ 100, vn_right) + @test vnv[vn_right] == val_right .+ 100 + end + + @testset "update_internal!" begin + vnv = deepcopy(vnv_base) + DynamicPPL.update_internal!(vnv, to_vec_left(val_left .+ 100), vn_left) + @test vnv[vn_left] == val_left .+ 100 + DynamicPPL.update_internal!(vnv, to_vec_right(val_right .+ 100), vn_right) + @test vnv[vn_right] == val_right .+ 100 + end + + @testset "update_internal!!" begin + vnv = deepcopy(vnv_base) + vnv = DynamicPPL.update_internal!!(vnv, to_vec_left(val_left .+ 100), vn_left) + @test vnv[vn_left] == val_left .+ 100 + vnv = DynamicPPL.update_internal!!( + vnv, to_vec_right(val_right .+ 100), vn_right + ) + @test vnv[vn_right] == val_right .+ 100 + end + + @testset "delete!" begin + vnv = deepcopy(vnv_base) + delete!(vnv, vn_left) + @test !haskey(vnv, vn_left) + @test haskey(vnv, vn_right) + delete!(vnv, vn_right) + @test !haskey(vnv, vn_right) + end + + @testset "insert!" begin + vnv = deepcopy(vnv_base) + delete!(vnv, vn_left) + delete!(vnv, vn_right) + DynamicPPL.insert!(vnv, val_left .+ 100, vn_left) + @test vnv[vn_left] == val_left .+ 100 + DynamicPPL.insert!(vnv, val_right .+ 100, vn_right) + @test vnv[vn_right] == val_right .+ 100 + end + + @testset "insert!!" begin + vnv = deepcopy(vnv_base) + delete!(vnv, vn_left) + delete!(vnv, vn_right) + vnv = DynamicPPL.insert!!(vnv, val_left .+ 100, vn_left) + @test vnv[vn_left] == val_left .+ 100 + vnv = DynamicPPL.insert!!(vnv, val_right .+ 100, vn_right) + @test vnv[vn_right] == val_right .+ 100 + end + + @testset "insert_internal!" 
begin + vnv = deepcopy(vnv_base) + delete!(vnv, vn_left) + delete!(vnv, vn_right) + DynamicPPL.insert_internal!( + vnv, to_vec_left(val_left .+ 100), vn_left, from_vec_left + ) + @test vnv[vn_left] == val_left .+ 100 + DynamicPPL.insert_internal!( + vnv, to_vec_right(val_right .+ 100), vn_right, from_vec_right + ) + @test vnv[vn_right] == val_right .+ 100 + end + + @testset "insert_internal!!" begin + vnv = deepcopy(vnv_base) + delete!(vnv, vn_left) + delete!(vnv, vn_right) + vnv = DynamicPPL.insert_internal!!( + vnv, to_vec_left(val_left .+ 100), vn_left, from_vec_left + ) + @test vnv[vn_left] == val_left .+ 100 + vnv = DynamicPPL.insert_internal!!( + vnv, to_vec_right(val_right .+ 100), vn_right, from_vec_right + ) + @test vnv[vn_right] == val_right .+ 100 + end + + @testset "merge" begin + # When there are no inactive entries, merging with itself should give the same result. + @test merge(vnv_base, vnv_base) == vnv_base + + # Merging with empty should result in the same. + @test merge(vnv_base, similar(vnv_base)) == vnv_base + @test merge(similar(vnv_base), vnv_base) == vnv_base + + # With differences. + vnv_left_only = deepcopy(vnv_base) + delete!(vnv_left_only, vn_right) + vnv_right_only = deepcopy(vnv_base) + delete!(vnv_right_only, vn_left) + + # `(x,)` and `(x, y)` should be `(x, y)`. + @test merge(vnv_left_only, vnv_base) == vnv_base + # `(x, y)` and `(x,)` should be `(x, y)`. + @test merge(vnv_base, vnv_left_only) == vnv_base + # `(x, y)` and `(y,)` should be `(x, y)`. + @test merge(vnv_base, vnv_right_only) == vnv_base + # `(y,)` and `(x, y)` should be `(y, x)`. + vnv_merged = merge(vnv_right_only, vnv_base) + @test vnv_merged != vnv_base + @test collect(keys(vnv_merged)) == [vn_right, vn_left] + end + + @testset "push!" 
begin + vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals) + @testset "$vn" for vn in test_vns + val = test_pairs[vn] + vnv_copy = deepcopy(vnv) + push!(vnv, (vn => val)) + @test vnv[vn] == val + end + end + + @testset "setindex_internal!" begin + # Not setting the transformation. + vnv = deepcopy(vnv_base) + DynamicPPL.setindex_internal!(vnv, to_vec_left(val_left .+ 100), vn_left) + @test vnv[vn_left] == val_left .+ 100 + DynamicPPL.setindex_internal!(vnv, to_vec_right(val_right .+ 100), vn_right) + @test vnv[vn_right] == val_right .+ 100 + + # Explicitly setting the transformation. + increment(x) = x .+ 10 + vnv = deepcopy(vnv_base) + vnv = DynamicPPL.loosen_types!!(vnv, typeof(vn_left), typeof(increment)) + DynamicPPL.setindex_internal!( + vnv, to_vec_left(val_left .+ 100), vn_left, increment + ) + @test vnv[vn_left] == to_vec_left(val_left .+ 110) + + vnv = DynamicPPL.loosen_types!!(vnv, typeof(vn_right), typeof(increment)) + DynamicPPL.setindex_internal!( + vnv, to_vec_right(val_right .+ 100), vn_right, increment + ) + @test vnv[vn_right] == to_vec_right(val_right .+ 110) + + # Adding new values. + vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals) + @testset "$vn" for vn in test_vns + val = test_pairs[vn] + from_vec_vn = DynamicPPL.from_vec_transform(val) + to_vec_vn = inverse(from_vec_vn) + DynamicPPL.setindex_internal!(vnv, to_vec_vn(val), vn, from_vec_vn) + @test vnv[vn] == val + end + end + + @testset "setindex_internal! with Ints" begin + vnv = deepcopy(vnv_base) + for i in 1:DynamicPPL.length_internal(vnv_base) + DynamicPPL.setindex_internal!(vnv, i, i) + end + for i in 1:DynamicPPL.length_internal(vnv_base) + @test DynamicPPL.getindex_internal(vnv, i) == i + end + end + + @testset "setindex_internal!!" begin + # Not setting the transformation. 
+ vnv = deepcopy(vnv_base) + vnv = DynamicPPL.setindex_internal!!(vnv, to_vec_left(val_left .+ 100), vn_left) + @test vnv[vn_left] == val_left .+ 100 + vnv = DynamicPPL.setindex_internal!!( + vnv, to_vec_right(val_right .+ 100), vn_right + ) + @test vnv[vn_right] == val_right .+ 100 + + # Explicitly setting the transformation. + # Note that unlike with setindex_internal!, we don't need loosen_types!! here. + increment(x) = x .+ 10 + vnv = deepcopy(vnv_base) + vnv = DynamicPPL.setindex_internal!!( + vnv, to_vec_left(val_left .+ 100), vn_left, increment + ) + @test vnv[vn_left] == to_vec_left(val_left .+ 110) + + vnv = DynamicPPL.setindex_internal!!( + vnv, to_vec_right(val_right .+ 100), vn_right, increment + ) + @test vnv[vn_right] == to_vec_right(val_right .+ 110) + + # Adding new values. + vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals) + @testset "$vn" for vn in test_vns + val = test_pairs[vn] + from_vec_vn = DynamicPPL.from_vec_transform(val) + to_vec_vn = inverse(from_vec_vn) + vnv = DynamicPPL.setindex_internal!!(vnv, to_vec_vn(val), vn, from_vec_vn) + @test vnv[vn] == val + end + end + + @testset "setindex! and reset!" begin + vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals) + @testset "$vn" for vn in test_vns + val = test_pairs[vn] + expected_length = if haskey(vnv, vn) + # If it's already present, the resulting length will be unchanged. + DynamicPPL.length_internal(vnv) + else + DynamicPPL.length_internal(vnv) + length(val) + end + + vnv[vn] = val .+ 1 + x = DynamicPPL.getindex_internal(vnv, :) + @test vnv[vn] == val .+ 1 + @test DynamicPPL.length_internal(vnv) == expected_length + @test length(x) == DynamicPPL.length_internal(vnv) + @test all( + DynamicPPL.getindex_internal(vnv, i) == x[i] for i in eachindex(x) + ) + + # There should be no redundant values in the underlying vector. 
+ @test !DynamicPPL.has_inactive(vnv) + end + + vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals) + @testset "$vn (increased size)" for vn in test_vns + val_original = test_pairs[vn] + val = increase_size_for_test(val_original) + vn_already_present = haskey(vnv, vn) + expected_length = if vn_already_present + # If it's already present, the resulting length will be altered. + DynamicPPL.length_internal(vnv) + length(val) - length(val_original) + else + DynamicPPL.length_internal(vnv) + length(val) + end + + # Have to use reset!, because setindex! doesn't support decreasing size. + DynamicPPL.reset!(vnv, val .+ 1, vn) + x = DynamicPPL.getindex_internal(vnv, :) + @test vnv[vn] == val .+ 1 + @test DynamicPPL.length_internal(vnv) == expected_length + @test length(x) == DynamicPPL.length_internal(vnv) + @test all( + DynamicPPL.getindex_internal(vnv, i) == x[i] for i in eachindex(x) + ) + end + + vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals) + @testset "$vn (decreased size)" for vn in test_vns + val_original = test_pairs[vn] + val = decrease_size_for_test(val_original) + vn_already_present = haskey(vnv, vn) + expected_length = if vn_already_present + # If it's already present, the resulting length will be altered. + DynamicPPL.length_internal(vnv) + length(val) - length(val_original) + else + DynamicPPL.length_internal(vnv) + length(val) + end + + # Have to use reset!, because setindex! doesn't support decreasing size. 
+ DynamicPPL.reset!(vnv, val .+ 1, vn) + x = DynamicPPL.getindex_internal(vnv, :) + @test vnv[vn] == val .+ 1 + @test DynamicPPL.length_internal(vnv) == expected_length + @test length(x) == DynamicPPL.length_internal(vnv) + @test all( + DynamicPPL.getindex_internal(vnv, i) == x[i] for i in eachindex(x) + ) + end + end + end + + @testset "growing and shrinking" begin + @testset "deterministic" begin + n = 5 + vn = @varname(x) + vnv = DynamicPPL.VarNamedVector(OrderedDict(vn => [true])) + @test !DynamicPPL.has_inactive(vnv) + # Growing should not create inactive ranges. + for i in 1:n + x = fill(true, i) + DynamicPPL.update_internal!(vnv, x, vn, identity) + @test !DynamicPPL.has_inactive(vnv) + end + + # Same size should not create inactive ranges. + x = fill(true, n) + DynamicPPL.update_internal!(vnv, x, vn, identity) + @test !DynamicPPL.has_inactive(vnv) + + # Shrinking should create inactive ranges. + for i in (n - 1):-1:1 + x = fill(true, i) + DynamicPPL.update_internal!(vnv, x, vn, identity) + @test DynamicPPL.has_inactive(vnv) + @test DynamicPPL.num_inactive(vnv, vn) == n - i + end + end + + @testset "random" begin + n = 5 + vn = @varname(x) + vnv = DynamicPPL.VarNamedVector(OrderedDict(vn => [true])) + @test !DynamicPPL.has_inactive(vnv) + + # Insert a bunch of random-length vectors. + for i in 1:100 + x = fill(true, rand(1:n)) + DynamicPPL.update!(vnv, x, vn) + end + # Should never be allocating more than `n` elements. + @test DynamicPPL.num_allocated(vnv, vn) ≤ n + + # If we contiguify, then it should always be the same size as just inserted. 
+ for i in 1:10 + x = fill(true, rand(1:n)) + DynamicPPL.update!(vnv, x, vn) + DynamicPPL.contiguify!(vnv) + @test DynamicPPL.num_allocated(vnv, vn) == length(x) + end + end + end + + @testset "subset" begin + vnv = DynamicPPL.VarNamedVector(test_pairs) + @test subset(vnv, test_vns) == vnv + @test subset(vnv, VarName[]) == DynamicPPL.VarNamedVector() + @test merge(subset(vnv, test_vns[1:3]), subset(vnv, test_vns[4:end])) == vnv + + # Test that subset preserves transformations and unconstrainedness. + vn = @varname(t[1]) + vns = vcat(test_vns, [vn]) + vnv = DynamicPPL.setindex_internal!!(vnv, [2.0], vn, x -> x .^ 2) + DynamicPPL.settrans!(vnv, true, @varname(t[1])) + @test vnv[@varname(t[1])] == [4.0] + @test istrans(vnv, @varname(t[1])) + @test subset(vnv, vns) == vnv + end +end + +@testset "VarInfo + VarNamedVector" begin + models = DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(model.f)" for model in models + # NOTE: Need to set random seed explicitly to avoid using the same seed + # for initialization as for sampling in the inner testset below. + Random.seed!(42) + value_true = DynamicPPL.TestUtils.rand_prior_true(model) + vns = DynamicPPL.TestUtils.varnames(model) + varnames = DynamicPPL.TestUtils.varnames(model) + varinfos = DynamicPPL.TestUtils.setup_varinfos( + model, value_true, varnames; include_threadsafe=false + ) + # Filter out those which are not based on `VarNamedVector`. + varinfos = filter(DynamicPPL.has_varnamedvector, varinfos) + # Get the true log joint. + logp_true = DynamicPPL.TestUtils.logjoint_true(model, value_true...) + + @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + # Need to make sure we're using a different random seed from the + # one used in the above call to `rand_prior_true`. + Random.seed!(43) + + # Are values correct? + DynamicPPL.TestUtils.test_values(varinfo, value_true, vns) + + # Is evaluation correct? 
+ varinfo_eval = last( + DynamicPPL.evaluate!!(model, deepcopy(varinfo), DefaultContext()) + ) + # Log density should be the same. + @test getlogp(varinfo_eval) ≈ logp_true + # Values should be the same. + DynamicPPL.TestUtils.test_values(varinfo_eval, value_true, vns) + + # Is sampling correct? + varinfo_sample = last( + DynamicPPL.evaluate!!(model, deepcopy(varinfo), SamplingContext()) + ) + # Log density should be different. + @test getlogp(varinfo_sample) != getlogp(varinfo) + # Values should be different. + DynamicPPL.TestUtils.test_values( + varinfo_sample, value_true, vns; compare=!isequal + ) + end + end +end