Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge libtask_ext upstream #2115

Merged
merged 21 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ TuringOptimExt = "Optim"
AbstractMCMC = "4"
AdvancedHMC = "0.3.0, 0.4.0, 0.5.2"
yebai marked this conversation as resolved.
Show resolved Hide resolved
AdvancedMH = "0.6.8, 0.7"
AdvancedPS = "0.4"
AdvancedPS = "0.5"
FredericWantiez marked this conversation as resolved.
Show resolved Hide resolved
AdvancedVI = "0.2"
BangBang = "0.3"
Bijectors = "0.13.6"
Expand Down
112 changes: 60 additions & 52 deletions src/essential/container.jl
Original file line number Diff line number Diff line change
@@ -1,73 +1,81 @@
struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,E<:Tuple}
model::M
sampler::S
varinfo::V
evaluator::E
struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,E<:Tuple} <: AdvancedPS.AbstractGenericModel
model::M
sampler::S
varinfo::V
evaluator::E
end

function TracedModel(
model::Model,
sampler::AbstractSampler,
varinfo::AbstractVarInfo,
rng::Random.AbstractRNG,
)
context = SamplingContext(rng, sampler, DefaultContext())
args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context)
if kwargs !== nothing && !isempty(kwargs)
error("Sampling with `$(sampler.alg)` does not support models with keyword arguments. See issue #2007 for more details.")
end
return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(
model,
sampler,
varinfo,
(model.f, args...)
)
model::Model,
sampler::AbstractSampler,
varinfo::AbstractVarInfo,
rng::Random.AbstractRNG,
)
context = SamplingContext(rng, sampler, DefaultContext())
args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context)
if kwargs !== nothing && !isempty(kwargs)
error("Sampling with `$(sampler.alg)` does not support models with keyword arguments. See issue #2007 for more details.")
end
return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(
model,
sampler,
varinfo,
(model.f, args...)
)
end

function Base.copy(model::AdvancedPS.GenericModel{<:TracedModel})
newtask = copy(model.ctask)
newmodel = TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(deepcopy(model.f.model), deepcopy(model.f.sampler), deepcopy(model.f.varinfo), deepcopy(model.f.evaluator))
gen_model = AdvancedPS.GenericModel(newmodel, newtask)
return gen_model
function Base.copy(model::AdvancedPS.LibtaskModel{<:TracedModel})
newtask = copy(model.ctask)
newmodel = TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(
yebai marked this conversation as resolved.
Show resolved Hide resolved
deepcopy(model.f.model),
deepcopy(model.f.sampler),
deepcopy(model.f.varinfo),
deepcopy(model.f.evaluator)
)
gen_model = AdvancedPS.LibtaskModel(newmodel, newtask)
return gen_model
end

function AdvancedPS.advance!(trace::AdvancedPS.Trace{<:AdvancedPS.GenericModel{<:TracedModel}}, isref::Bool=false)
# Make sure we load/reset the rng in the new replaying mechanism
DynamicPPL.increment_num_produce!(trace.model.f.varinfo)
isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng)
score = consume(trace.model.ctask)
if score === nothing
return
else
return score + DynamicPPL.getlogp(trace.model.f.varinfo)
end
function AdvancedPS.advance!(
trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}},
isref::Bool=false
)
# Make sure we load/reset the rng in the new replaying mechanism
DynamicPPL.increment_num_produce!(trace.model.f.varinfo)
isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng)
score = consume(trace.model.ctask)
if score === nothing
return
else
return score + DynamicPPL.getlogp(trace.model.f.varinfo)
end
end

function AdvancedPS.delete_retained!(trace::TracedModel)
DynamicPPL.set_retained_vns_del_by_spl!(trace.varinfo, trace.sampler)
return trace
DynamicPPL.set_retained_vns_del_by_spl!(trace.varinfo, trace.sampler)
return trace
end

function AdvancedPS.reset_model(trace::TracedModel)
DynamicPPL.reset_num_produce!(trace.varinfo)
return trace
DynamicPPL.reset_num_produce!(trace.varinfo)
return trace
end

function AdvancedPS.reset_logprob!(trace::TracedModel)
DynamicPPL.resetlogp!!(trace.model.varinfo)
return trace
DynamicPPL.resetlogp!!(trace.model.varinfo)
return trace
end

function AdvancedPS.update_rng!(trace::AdvancedPS.Trace{AdvancedPS.GenericModel{TracedModel{M,S,V,E}, F}, R}) where {M,S,V,E,F,R}
# Extract the `args`.
args = trace.model.ctask.args
# From `args`, extract the `SamplingContext`, which contains the RNG.
sampling_context = args[3]
rng = sampling_context.rng
trace.rng = rng
return trace
function AdvancedPS.update_rng!(trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}})
# Extract the `args`.
args = trace.model.ctask.args
# From `args`, extract the `SamplingContext`, which contains the RNG.
sampling_context = args[3]
rng = sampling_context.rng
trace.rng = rng
return trace
end

function Libtask.TapedTask(model::TracedModel, rng::Random.AbstractRNG; kwargs...)
return Libtask.TapedTask(model.evaluator[1], model.evaluator[2:end]...; kwargs...)
function Libtask.TapedTask(model::TracedModel, ::Random.AbstractRNG, args...; kwargs...) # RNG ?
return Libtask.TapedTask(model.evaluator[1], model.evaluator[2:end]...; kwargs...)
end
12 changes: 5 additions & 7 deletions src/mcmc/particle_mcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ function DynamicPPL.initialstep(
)

# Perform particle sweep.
logevidence = AdvancedPS.sweep!(rng, particles, spl.alg.resampler)
logevidence = AdvancedPS.sweep!(rng, particles, spl.alg.resampler, spl)

# Extract the first particle and its weight.
particle = particles.vals[1]
Expand Down Expand Up @@ -264,7 +264,7 @@ function DynamicPPL.initialstep(
)

# Perform a particle sweep.
logevidence = AdvancedPS.sweep!(rng, particles, spl.alg.resampler)
logevidence = AdvancedPS.sweep!(rng, particles, spl.alg.resampler, spl)

# Pick a particle to be retained.
Ws = AdvancedPS.getweights(particles)
Expand Down Expand Up @@ -308,7 +308,7 @@ function AbstractMCMC.step(
particles = AdvancedPS.ParticleContainer(x, AdvancedPS.TracedRNG(), rng)

# Perform a particle sweep.
logevidence = AdvancedPS.sweep!(rng, particles, spl.alg.resampler, reference)
logevidence = AdvancedPS.sweep!(rng, particles, spl.alg.resampler, spl, reference)

# Pick a particle to be retained.
Ws = AdvancedPS.getweights(particles)
Expand Down Expand Up @@ -389,9 +389,7 @@ function AdvancedPS.Trace(
DynamicPPL.reset_num_produce!(newvarinfo)

tmodel = Turing.Essential.TracedModel(model, sampler, newvarinfo, rng)
ttask = Libtask.TapedTask(tmodel, rng; deepcopy_types=Union{typeof(rng), typeof(model)})
wrapedmodel = AdvancedPS.GenericModel(tmodel, ttask)

newtrace = AdvancedPS.Trace(wrapedmodel, rng)
newtrace = AdvancedPS.Trace(tmodel, rng)
AdvancedPS.addreference!(newtrace.model.ctask.task, newtrace)
return newtrace
end
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[compat]
AbstractMCMC = "4"
AdvancedMH = "0.6, 0.7"
AdvancedPS = "0.4"
AdvancedPS = "0.5.2"
AdvancedVI = "0.2"
Clustering = "0.14, 0.15"
Distributions = "0.25"
Expand Down
50 changes: 50 additions & 0 deletions test/essential/container.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
@testset "container.jl" begin
@model function test()
a ~ Normal(0, 1)
x ~ Bernoulli(1)
b ~ Gamma(2, 3)
1 ~ Bernoulli(x / 2)
c ~ Beta()
0 ~ Bernoulli(x / 2)
x
end

@turing_testset "constructor" begin
vi = DynamicPPL.VarInfo()
sampler = Sampler(PG(10))
model = test()
trace = AdvancedPS.Trace(model, sampler, vi, AdvancedPS.TracedRNG())

# Make sure we link the traces
@test haskey(trace.model.ctask.task.storage, :__trace)
yebai marked this conversation as resolved.
Show resolved Hide resolved

res = AdvancedPS.advance!(trace, false)
@test DynamicPPL.get_num_produce(trace.model.f.varinfo) == 1
@test res ≈ -log(2)

# Catch broken copy, espetially for RNG / VarInfo
newtrace = AdvancedPS.fork(trace)
res2 = AdvancedPS.advance!(trace)
@test DynamicPPL.get_num_produce(trace.model.f.varinfo) == 2
@test DynamicPPL.get_num_produce(newtrace.model.f.varinfo) == 1
end

@turing_testset "fork" begin
@model function normal()
a ~ Normal(0, 1)
3 ~ Normal(a, 2)
b ~ Normal(a, 1)
1.5 ~ Normal(b, 2)
a, b
end
vi = DynamicPPL.VarInfo()
sampler = Sampler(PG(10))
model = normal()

trace = AdvancedPS.Trace(model, sampler, vi, AdvancedPS.TracedRNG())

newtrace = AdvancedPS.forkr(trace)
# Catch broken replay mechanism
@test AdvancedPS.advance!(trace) ≈ AdvancedPS.advance!(newtrace)
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using AbstractMCMC
using AdvancedMH
using AdvancedPS
using Clustering
using Distributions
using Distributions.FillArrays
Expand Down