Merge libtask_ext upstream #2115

Merged · 21 commits · Nov 10, 2023
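Summary: this PR moves Turing's Libtask integration upstream into AdvancedPS. `TracedModel` becomes an `AdvancedPS.AbstractGenericModel`, so AdvancedPS now owns the `LibtaskModel`/`TapedTask` wrapping and copying; compat is bumped to AdvancedPS 0.5.4, DynamicPPL 0.24, AdvancedMH 0.8, and AbstractMCMC 4/5; and the bespoke `resume` function is removed in favour of the standard `resume_from`/`initial_state` keywords.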
12 changes: 6 additions & 6 deletions Project.toml
@@ -44,10 +44,10 @@ TuringDynamicHMCExt = "DynamicHMC"
TuringOptimExt = "Optim"

[compat]
AbstractMCMC = "4"
AdvancedHMC = "0.3.0, 0.4.0, 0.5.2"
AdvancedMH = "0.6.8, 0.7"
AdvancedPS = "0.4"
AbstractMCMC = "4, 5"
AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6"
AdvancedMH = "0.8"
AdvancedPS = "0.5.4"
AdvancedVI = "0.2"
BangBang = "0.3"
Bijectors = "0.13.6"
@@ -56,8 +56,8 @@ Distributions = "0.23.3, 0.24, 0.25"
DistributionsAD = "0.6"
DocStringExtensions = "0.8, 0.9"
DynamicHMC = "3.4"
DynamicPPL = "0.23.17"
EllipticalSliceSampling = "0.5, 1"
DynamicPPL = "0.24"
EllipticalSliceSampling = "0.5, 1, 2"
ForwardDiff = "0.10.3"
Libtask = "0.7, 0.8"
LogDensityProblems = "2"
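A note on the `[compat]` format, for reviewers: each entry is a comma-separated union of ranges, so `AbstractMCMC = "4, 5"` accepts any 4.x or 5.x release. A hedged sketch (not part of the diff) for checking what the resolver picks against these bounds:

```julia
using Pkg

Pkg.activate(; temp=true)   # throwaway environment, discarded on exit
Pkg.add("Turing")           # the resolver must satisfy the [compat] bounds above
Pkg.status(["AbstractMCMC", "AdvancedPS", "DynamicPPL"])
```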
1 change: 0 additions & 1 deletion src/Turing.jl
@@ -94,7 +94,6 @@ export @model, # modelling
ADVI,

sample, # inference
-resume,
@logprob_str,
@prob_str,
externalsampler,
18 changes: 7 additions & 11 deletions src/essential/container.jl
@@ -1,4 +1,4 @@
-struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,E<:Tuple}
+struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,E<:Tuple} <: AdvancedPS.AbstractGenericModel
model::M
sampler::S
varinfo::V
@@ -24,14 +24,10 @@ function TracedModel(
)
end

-function Base.copy(model::AdvancedPS.GenericModel{<:TracedModel})
-newtask = copy(model.ctask)
-newmodel = TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(deepcopy(model.f.model), deepcopy(model.f.sampler), deepcopy(model.f.varinfo), deepcopy(model.f.evaluator))
-gen_model = AdvancedPS.GenericModel(newmodel, newtask)
-return gen_model
-end

-function AdvancedPS.advance!(trace::AdvancedPS.Trace{<:AdvancedPS.GenericModel{<:TracedModel}}, isref::Bool=false)
+function AdvancedPS.advance!(
+trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}},
+isref::Bool=false
+)
# Make sure we load/reset the rng in the new replaying mechanism
DynamicPPL.increment_num_produce!(trace.model.f.varinfo)
isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng)
@@ -58,7 +54,7 @@ function AdvancedPS.reset_logprob!(trace::TracedModel)
return trace
end

-function AdvancedPS.update_rng!(trace::AdvancedPS.Trace{AdvancedPS.GenericModel{TracedModel{M,S,V,E}, F}, R}) where {M,S,V,E,F,R}
+function AdvancedPS.update_rng!(trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}})
# Extract the `args`.
args = trace.model.ctask.args
# From `args`, extract the `SamplingContext`, which contains the RNG.
@@ -68,6 +64,6 @@ function AdvancedPS.update_rng!(trace::AdvancedPS.Trace{AdvancedPS.GenericModel{
return trace
end

-function Libtask.TapedTask(model::TracedModel, rng::Random.AbstractRNG; kwargs...)
+function Libtask.TapedTask(model::TracedModel, ::Random.AbstractRNG, args...; kwargs...) # the RNG argument is accepted for interface compatibility but currently unused
return Libtask.TapedTask(model.evaluator[1], model.evaluator[2:end]...; kwargs...)
end
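Taken together, these changes hand the Libtask plumbing to AdvancedPS: because `TracedModel` now subtypes `AdvancedPS.AbstractGenericModel`, AdvancedPS builds and copies the `LibtaskModel`/`TapedTask` pair itself, which is why the hand-rolled `Base.copy` above could be deleted and why `advance!`/`update_rng!` now dispatch on `AdvancedPS.LibtaskModel` instead of `GenericModel`. A minimal sketch of the resulting call path, mirroring the new tests in `test/essential/container.jl` (the model is illustrative):

```julia
using Turing, AdvancedPS, DynamicPPL

@model function demo()
    a ~ Normal(0, 1)
    1 ~ Bernoulli(0.5)  # one observation, i.e. one `produce` step per sweep
end

vi = DynamicPPL.VarInfo()
sampler = DynamicPPL.Sampler(PG(10))

# AdvancedPS wraps the TracedModel in its own LibtaskModel, calling the
# `Libtask.TapedTask` constructor defined above.
trace = AdvancedPS.Trace(demo(), sampler, vi, AdvancedPS.TracedRNG())
logweight = AdvancedPS.advance!(trace, false)  # ≈ log(0.5) for this model
```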
83 changes: 21 additions & 62 deletions src/mcmc/Inference.jl
@@ -64,7 +64,6 @@ export InferenceAlgorithm,
dot_assume,
observe,
dot_observe,
-resume,
predict,
isgibbscomponent,
externalsampler
@@ -190,42 +189,6 @@ function AbstractMCMC.sample(
return AbstractMCMC.sample(rng, model, Sampler(alg, model), N; kwargs...)
end

-function AbstractMCMC.sample(
-rng::AbstractRNG,
-model::AbstractModel,
-sampler::Sampler{<:InferenceAlgorithm},
-N::Integer;
-chain_type=MCMCChains.Chains,
-resume_from=nothing,
-progress=PROGRESS[],
-kwargs...
-)
-if resume_from === nothing
-return AbstractMCMC.mcmcsample(rng, model, sampler, N;
-chain_type=chain_type, progress=progress, kwargs...)
-else
-return resume(resume_from, N; chain_type=chain_type, progress=progress, kwargs...)
-end
-end
-
-function AbstractMCMC.sample(
-rng::AbstractRNG,
-model::AbstractModel,
-alg::Prior,
-N::Integer;
-chain_type=MCMCChains.Chains,
-resume_from=nothing,
-progress=PROGRESS[],
-kwargs...
-)
-if resume_from === nothing
-return AbstractMCMC.mcmcsample(rng, model, SampleFromPrior(), N;
-chain_type=chain_type, progress=progress, kwargs...)
-else
-return resume(resume_from, N; chain_type=chain_type, progress=progress, kwargs...)
-end
-end

function AbstractMCMC.sample(
model::AbstractModel,
alg::InferenceAlgorithm,
@@ -273,17 +236,36 @@ function AbstractMCMC.sample(
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
N::Integer,
n_chains::Integer;
-chain_type=MCMCChains.Chains,
+chain_type=DynamicPPL.default_chain_type(alg),
progress=PROGRESS[],
kwargs...
)
return AbstractMCMC.sample(rng, model, SampleFromPrior(), ensemble, N, n_chains;
-chain_type=chain_type, progress=progress, kwargs...)
+chain_type, progress, kwargs...)
end

+function AbstractMCMC.sample(
+rng::AbstractRNG,
+model::AbstractModel,
+alg::Prior,
+N::Integer;
+chain_type=DynamicPPL.default_chain_type(alg),
+resume_from=nothing,
+initial_state=DynamicPPL.loadstate(resume_from),
+progress=PROGRESS[],
+kwargs...
+)
+return AbstractMCMC.mcmcsample(rng, model, SampleFromPrior(), N;
+chain_type, initial_state, progress, kwargs...)
+end

##########################
# Chain making utilities #
##########################

+DynamicPPL.default_chain_type(sampler::Prior) = MCMCChains.Chains
+DynamicPPL.default_chain_type(sampler::Sampler{<:InferenceAlgorithm}) = MCMCChains.Chains

"""
getparams(model, t)

@@ -477,29 +459,6 @@ function save(c::MCMCChains.Chains, spl::Sampler, model, vi, samples)
return setinfo(c, merge(nt, c.info))
end

-function resume(chain::MCMCChains.Chains, args...; kwargs...)
-return resume(Random.default_rng(), chain, args...; kwargs...)
-end
-
-function resume(rng::Random.AbstractRNG, chain::MCMCChains.Chains, args...;
-progress=PROGRESS[], kwargs...)
-isempty(chain.info) && error("[Turing] cannot resume from a chain without state info")
-
-# Sample a new chain.
-return AbstractMCMC.mcmcsample(
-rng,
-chain.info[:model],
-chain.info[:sampler],
-args...;
-resume_from = chain,
-chain_type = MCMCChains.Chains,
-progress = progress,
-kwargs...
-)
-end
-
-DynamicPPL.loadstate(chain::MCMCChains.Chains) = chain.info[:samplerstate]

#######################################
# Concrete algorithm implementations. #
#######################################
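With `resume` and the `chain.info`-based machinery gone, continuing a run now goes through the standard keywords: save the sampler state with `save_state=true`, then pass the chain as `resume_from`, which feeds `DynamicPPL.loadstate` into `initial_state`. A minimal sketch mirroring the updated tests (the model and step sizes are illustrative):

```julia
using Turing

@model function gdemo(x)
    s² ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s²))
    for i in eachindex(x)
        x[i] ~ Normal(m, sqrt(s²))
    end
end

model = gdemo([1.5, 2.0])
chain = sample(model, HMC(0.1, 5), 1000; save_state=true)    # keeps sampler state in chain.info
chain2 = sample(model, HMC(0.1, 5), 1000; resume_from=chain) # continues from the saved state
```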
12 changes: 8 additions & 4 deletions src/mcmc/hmc.jl
@@ -90,11 +90,12 @@ DynamicPPL.initialsampler(::Sampler{<:Hamiltonian}) = SampleFromUniform()
# Handle setting `nadapts` and `discard_initial`
function AbstractMCMC.sample(
rng::AbstractRNG,
-model::AbstractModel,
+model::DynamicPPL.Model,
sampler::Sampler{<:AdaptiveHamiltonian},
N::Integer;
-chain_type=MCMCChains.Chains,
+chain_type=DynamicPPL.default_chain_type(sampler),
resume_from=nothing,
+initial_state=DynamicPPL.loadstate(resume_from),
progress=PROGRESS[],
nadapts=sampler.alg.n_adapts,
discard_adapt=true,
@@ -123,8 +124,11 @@ function AbstractMCMC.sample(
nadapts=_nadapts, discard_initial=_discard_initial,
kwargs...)
else
-return resume(resume_from, N; chain_type=chain_type, progress=progress,
-nadapts=0, discard_adapt=false, discard_initial=0, kwargs...)
+return AbstractMCMC.mcmcsample(
+rng, model, sampler, N;
+chain_type=chain_type, initial_state=initial_state, progress=progress,
+nadapts=0, discard_adapt=false, discard_initial=0, kwargs...
+)
end
end

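Note the resume semantics preserved here: a continuation run passes `nadapts=0`, `discard_adapt=false`, and `discard_initial=0`, so step-size/metric adaptation is not restarted and no draws from the resumed run are discarded.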
23 changes: 12 additions & 11 deletions src/mcmc/particle_mcmc.jl
@@ -79,11 +79,12 @@ end

function AbstractMCMC.sample(
rng::AbstractRNG,
-model::AbstractModel,
+model::DynamicPPL.Model,
sampler::Sampler{<:SMC},
N::Integer;
-chain_type=MCMCChains.Chains,
+chain_type=DynamicPPL.default_chain_type(sampler),
resume_from=nothing,
+initial_state=DynamicPPL.loadstate(resume_from),
progress=PROGRESS[],
kwargs...
)
@@ -94,8 +95,10 @@ function AbstractMCMC.sample(
nparticles=N,
kwargs...)
else
-return resume(resume_from, N;
-chain_type=chain_type, progress=progress, nparticles=N, kwargs...)
+return AbstractMCMC.mcmcsample(
+rng, model, sampler, N; chain_type, initial_state, progress=progress,
+nparticles=N, kwargs...
+)
end
end

@@ -121,7 +124,7 @@ function DynamicPPL.initialstep(
)

# Perform particle sweep.
-logevidence = AdvancedPS.sweep!(rng, particles, spl.alg.resampler)
+logevidence = AdvancedPS.sweep!(rng, particles, spl.alg.resampler, spl)

# Extract the first particle and its weight.
particle = particles.vals[1]
@@ -264,7 +267,7 @@ function DynamicPPL.initialstep(
)

# Perform a particle sweep.
-logevidence = AdvancedPS.sweep!(rng, particles, spl.alg.resampler)
+logevidence = AdvancedPS.sweep!(rng, particles, spl.alg.resampler, spl)

# Pick a particle to be retained.
Ws = AdvancedPS.getweights(particles)
@@ -308,7 +311,7 @@ function AbstractMCMC.step(
particles = AdvancedPS.ParticleContainer(x, AdvancedPS.TracedRNG(), rng)

# Perform a particle sweep.
-logevidence = AdvancedPS.sweep!(rng, particles, spl.alg.resampler, reference)
+logevidence = AdvancedPS.sweep!(rng, particles, spl.alg.resampler, spl, reference)

# Pick a particle to be retained.
Ws = AdvancedPS.getweights(particles)
@@ -389,9 +392,7 @@ function AdvancedPS.Trace(
DynamicPPL.reset_num_produce!(newvarinfo)

tmodel = Turing.Essential.TracedModel(model, sampler, newvarinfo, rng)
-ttask = Libtask.TapedTask(tmodel, rng; deepcopy_types=Union{typeof(rng), typeof(model)})
-wrapedmodel = AdvancedPS.GenericModel(tmodel, ttask)
-
-newtrace = AdvancedPS.Trace(wrapedmodel, rng)
+newtrace = AdvancedPS.Trace(tmodel, rng)
AdvancedPS.addreference!(newtrace.model.ctask.task, newtrace)
return newtrace
end
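The user-facing particle samplers are unchanged by the `sweep!` signature change (it now also receives the sampler) and the simplified `Trace` construction. A hedged usage sketch (the model is illustrative):

```julia
using Turing

@model function coinflip(y)
    p ~ Beta(1, 1)
    for i in eachindex(y)
        y[i] ~ Bernoulli(p)
    end
end

data = [1, 0, 1, 1]
chain_smc = sample(coinflip(data), SMC(), 500)   # drives AdvancedPS.sweep! internally
chain_pg  = sample(coinflip(data), PG(10), 250)  # particle Gibbs, 10 particles
```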
8 changes: 4 additions & 4 deletions test/Project.toml
@@ -32,15 +32,15 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractMCMC = "4"
AdvancedMH = "0.6, 0.7"
AdvancedPS = "0.4"
AbstractMCMC = "4, 5"
AdvancedMH = "0.6, 0.7, 0.8"
AdvancedPS = "0.5.4"
AdvancedVI = "0.2"
Clustering = "0.14, 0.15"
Distributions = "0.25"
DistributionsAD = "0.6.3"
DynamicHMC = "2.1.6, 3.0"
DynamicPPL = "0.23"
DynamicPPL = "0.24"
FiniteDifferences = "0.10.8, 0.11, 0.12"
ForwardDiff = "0.10.12 - 0.10.32, 0.10"
LogDensityProblems = "2"
50 changes: 50 additions & 0 deletions test/essential/container.jl
@@ -0,0 +1,50 @@
@testset "container.jl" begin
@model function test()
a ~ Normal(0, 1)
x ~ Bernoulli(1)
b ~ Gamma(2, 3)
1 ~ Bernoulli(x / 2)
c ~ Beta()
0 ~ Bernoulli(x / 2)
x
end

@turing_testset "constructor" begin
vi = DynamicPPL.VarInfo()
sampler = Sampler(PG(10))
model = test()
trace = AdvancedPS.Trace(model, sampler, vi, AdvancedPS.TracedRNG())

# Make sure we link the traces
@test haskey(trace.model.ctask.task.storage, :__trace)

res = AdvancedPS.advance!(trace, false)
@test DynamicPPL.get_num_produce(trace.model.f.varinfo) == 1
@test res ≈ -log(2)

# Catch broken copy, especially for RNG / VarInfo
newtrace = AdvancedPS.fork(trace)
res2 = AdvancedPS.advance!(trace)
@test DynamicPPL.get_num_produce(trace.model.f.varinfo) == 2
@test DynamicPPL.get_num_produce(newtrace.model.f.varinfo) == 1
end

@turing_testset "fork" begin
@model function normal()
a ~ Normal(0, 1)
3 ~ Normal(a, 2)
b ~ Normal(a, 1)
1.5 ~ Normal(b, 2)
a, b
end
vi = DynamicPPL.VarInfo()
sampler = Sampler(PG(10))
model = normal()

trace = AdvancedPS.Trace(model, sampler, vi, AdvancedPS.TracedRNG())

newtrace = AdvancedPS.forkr(trace)
# Catch broken replay mechanism
@test AdvancedPS.advance!(trace) ≈ AdvancedPS.advance!(newtrace)
end
end
3 changes: 0 additions & 3 deletions test/mcmc/Inference.jl
@@ -71,9 +71,6 @@
chn1 = sample(gdemo_default, alg1, 5000; save_state=true)
check_gdemo(chn1)

-chn1_resumed = Turing.Inference.resume(chn1, 2000)
-check_gdemo(chn1_resumed)

chn1_contd = sample(gdemo_default, alg1, 5000; resume_from=chn1)
check_gdemo(chn1_contd)

1 change: 1 addition & 0 deletions test/runtests.jl
@@ -1,5 +1,6 @@
using AbstractMCMC
using AdvancedMH
+using AdvancedPS
using Clustering
using Distributions
using Distributions.FillArrays