Skip to content

Commit

Permalink
Handle nested PrefixContext (#787)
Browse files Browse the repository at this point in the history
* Prefix varnames appropriately inside check_model_and_trace

* Fix values_as_in_model as well

* Add test for check_model with manual prefix

* Add values_as_in_model tests

* Add tests for prefix nesting

* Bump Project.toml
  • Loading branch information
penelopeysm authored Jan 27, 2025
1 parent 00e7ee3 commit 29a6c7e
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 19 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.34.1"
version = "0.34.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
15 changes: 9 additions & 6 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ end

const PREFIX_SEPARATOR = Symbol(".")

# TODO(penelopeysm): Prefixing arguably occurs the wrong way round here
function PrefixContext{PrefixInner}(
context::PrefixContext{PrefixOuter}
) where {PrefixInner,PrefixOuter}
Expand All @@ -273,13 +274,15 @@ function PrefixContext{PrefixInner}(
end
end

function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym}
if @generated
return :(VarName{$(QuoteNode(Symbol(Prefix, PREFIX_SEPARATOR, Sym)))}(getoptic(vn)))
else
VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(getoptic(vn))
end
# TODO(penelopeysm): Prefixing arguably occurs the wrong way round here
function prefix(ctx::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym}
return prefix(
childcontext(ctx), VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(getoptic(vn))
)
end
prefix(ctx::AbstractContext, vn::VarName) = prefix(NodeTrait(ctx), ctx, vn)
prefix(::IsLeaf, ::AbstractContext, vn::VarName) = vn
prefix(::IsParent, ctx::AbstractContext, vn::VarName) = prefix(childcontext(ctx), vn)

"""
prefix(model::Model, x)
Expand Down
23 changes: 12 additions & 11 deletions src/debug_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,50 +239,51 @@ function DynamicPPL.setchildcontext(context::DebugContext, child)
end

function record_varname!(context::DebugContext, varname::VarName, dist)
if haskey(context.varnames_seen, varname)
prefixed_varname = prefix(context, varname)
if haskey(context.varnames_seen, prefixed_varname)
if context.error_on_failure
error("varname $varname used multiple times in model")
error("varname $prefixed_varname used multiple times in model")
else
@warn "varname $varname used multiple times in model"
@warn "varname $prefixed_varname used multiple times in model"
end
context.varnames_seen[varname] += 1
context.varnames_seen[prefixed_varname] += 1
else
# We need to check:
# 1. Does this `varname` subsume any of the other keys.
# 2. Does any of the other keys subsume `varname`.
vns = collect(keys(context.varnames_seen))
# Is `varname` subsumed by any of the other keys?
idx_parent = findfirst(Base.Fix2(subsumes, varname), vns)
idx_parent = findfirst(Base.Fix2(subsumes, prefixed_varname), vns)
if idx_parent !== nothing
varname_parent = vns[idx_parent]
if context.error_on_failure
error(
"varname $(varname_parent) used multiple times in model (subsumes $varname)",
"varname $(varname_parent) used multiple times in model (subsumes $prefixed_varname)",
)
else
@warn "varname $(varname_parent) used multiple times in model (subsumes $varname)"
@warn "varname $(varname_parent) used multiple times in model (subsumes $prefixed_varname)"
end
# Update count of parent.
context.varnames_seen[varname_parent] += 1
else
# Does `varname` subsume any of the other keys?
idx_child = findfirst(Base.Fix1(subsumes, varname), vns)
idx_child = findfirst(Base.Fix1(subsumes, prefixed_varname), vns)
if idx_child !== nothing
varname_child = vns[idx_child]
if context.error_on_failure
error(
"varname $(varname_child) used multiple times in model (subsumed by $varname)",
"varname $(varname_child) used multiple times in model (subsumed by $prefixed_varname)",
)
else
@warn "varname $(varname_child) used multiple times in model (subsumed by $varname)"
@warn "varname $(varname_child) used multiple times in model (subsumed by $prefixed_varname)"
end

# Update count of child.
context.varnames_seen[varname_child] += 1
end
end

context.varnames_seen[varname] = 1
context.varnames_seen[prefixed_varname] = 1
end
end

Expand Down
2 changes: 1 addition & 1 deletion src/values_as_in_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ is_extracting_values(::IsParent, ::AbstractContext) = false
is_extracting_values(::IsLeaf, ::AbstractContext) = false

function Base.push!(context::ValuesAsInModelContext, vn::VarName, value)
return setindex!(context.values, copy(value), vn)
return setindex!(context.values, copy(value), prefix(context, vn))
end

function broadcast_push!(context::ValuesAsInModelContext, vns, values)
Expand Down
20 changes: 20 additions & 0 deletions test/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,26 @@ end
@test getoptic(vn_prefixed) === getoptic(vn)
end

@testset "nested within arbitrary context stacks" begin
vn = @varname(x[1])
ctx1 = PrefixContext{:a}(DefaultContext())
ctx2 = SamplingContext(ctx1)
ctx3 = PrefixContext{:b}(ctx2)
ctx4 = DynamicPPL.ValuesAsInModelContext(OrderedDict(), false, ctx3)
vn_prefixed1 = prefix(ctx1, vn)
vn_prefixed2 = prefix(ctx2, vn)
vn_prefixed3 = prefix(ctx3, vn)
vn_prefixed4 = prefix(ctx4, vn)
@test DynamicPPL.getsym(vn_prefixed1) == Symbol("a.x")
@test DynamicPPL.getsym(vn_prefixed2) == Symbol("a.x")
@test DynamicPPL.getsym(vn_prefixed3) == Symbol("a.b.x")
@test DynamicPPL.getsym(vn_prefixed4) == Symbol("a.b.x")
@test DynamicPPL.getoptic(vn_prefixed1) === DynamicPPL.getoptic(vn)
@test DynamicPPL.getoptic(vn_prefixed2) === DynamicPPL.getoptic(vn)
@test DynamicPPL.getoptic(vn_prefixed3) === DynamicPPL.getoptic(vn)
@test DynamicPPL.getoptic(vn_prefixed4) === DynamicPPL.getoptic(vn)
end

context = DynamicPPL.PrefixContext{:prefix}(SamplingContext())
@testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
# Sample with the context.
Expand Down
9 changes: 9 additions & 0 deletions test/debug_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,15 @@
end
model = ModelOuterWorking()
@test check_model(model; error_on_failure=true)

# With manual prefixing, https://github.com/TuringLang/DynamicPPL.jl/issues/785
@model function ModelOuterWorking2()
x1 ~ to_submodel(prefix(ModelInner(), :a), false)
x2 ~ to_submodel(prefix(ModelInner(), :b), false)
return (x1, x2)
end
model = ModelOuterWorking2()
@test check_model(model; error_on_failure=true)
end

@testset "subsumes (x then x[1])" begin
Expand Down
21 changes: 21 additions & 0 deletions test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,27 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
end
end
end

@testset "Prefixing" begin
@model inner() = x ~ Normal()

@model function outer_auto_prefix()
a ~ to_submodel(inner(), true)
b ~ to_submodel(inner(), true)
return nothing
end
@model function outer_manual_prefix()
a ~ to_submodel(prefix(inner(), :a), false)
b ~ to_submodel(prefix(inner(), :b), false)
return nothing
end

for model in (outer_auto_prefix(), outer_manual_prefix())
vi = VarInfo(model)
vns = Set(keys(values_as_in_model(model, false, vi)))
@test vns == Set([@varname(var"a.x"), @varname(var"b.x")])
end
end
end

@testset "Erroneous model call" begin
Expand Down

4 comments on commit 29a6c7e

@penelopeysm
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:

  • Fixed bugs in ValuesAsInModelContext as well as DebugContext where underlying PrefixContexts were not being applied. From a user-facing perspective, this means that for models which use manually prefixed submodels, e.g.
using DynamicPPL, Distributions

@model inner() = x ~ Normal()

@model function outer()
  x1 ~ to_submodel(prefix(inner(), :a), false)
  x2 ~ to_submodel(prefix(inner(), :b), false)
end

will: (1) no longer error when sampling due to check_model_and_trace; and (2) contain both submodel's variables in the resulting chain (the behaviour before this patch was that the second x would override the first x).

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/123809

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.34.2 -m "<description of version>" 29a6c7ec0cd62d3a4d1dc18a304d5e4d1e024cfb
git push origin v0.34.2

@penelopeysm
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:

  • Fixed bugs in ValuesAsInModelContext as well as DebugContext where underlying PrefixContexts were not being applied. From a user-facing perspective, this means that for models which use manually prefixed submodels, e.g.
using DynamicPPL, Distributions

@model inner() = x ~ Normal()

@model function outer()
  x1 ~ to_submodel(prefix(inner(), :a), false)
  x2 ~ to_submodel(prefix(inner(), :b), false)
end

will: (1) no longer error when sampling due to check_model_and_trace; and (2) contain both submodel's variables in the resulting chain (the behaviour before this patch was that the second x would override the first x).

  • More broadly, implemented a general prefix(ctx::AbstractContext, ::VarName) which traverses the context tree in ctx to apply all necessary prefixes. This was a necessary step in fixing the above issues, but it also means that prefix is now capable of handling context trees with e.g. multiple prefixes at different levels of nesting.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request updated: JuliaRegistries/General/123809

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.34.2 -m "<description of version>" 29a6c7ec0cd62d3a4d1dc18a304d5e4d1e024cfb
git push origin v0.34.2

Please sign in to comment.