Check for observed variables in the trace #7641

Merged
merged 5 commits into from
Jan 20, 2025

Contributor

@zaxtax zaxtax commented Jan 12, 2025

Description

This introduces the enhancement discussed in #7225.

sample_posterior_predictive now also checks the trace for observed variables, which makes it easier to use in model factory workflows.

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7641.org.readthedocs.build/en/7641/

@zaxtax zaxtax requested a review from ricardoV94 January 12, 2025 16:38
@zaxtax zaxtax force-pushed the pull_observes_from_idata_in_pps branch from ca7c8f8 to 567fb2a Compare January 12, 2025 16:47

codecov bot commented Jan 12, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 92.78%. Comparing base (e6767ab) to head (81eef66).
Report is 5 commits behind head on main.

@@           Coverage Diff           @@
##             main    #7641   +/-   ##
=======================================
  Coverage   92.77%   92.78%           
=======================================
  Files         107      107           
  Lines       18178    18185    +7     
=======================================
+ Hits        16865    16873    +8     
+ Misses       1313     1312    -1     
Files with missing lines     Coverage Δ
pymc/sampling/forward.py     96.26% <100.00%> (+0.11%) ⬆️

... and 1 file with indirect coverage changes

@zaxtax zaxtax force-pushed the pull_observes_from_idata_in_pps branch from 567fb2a to 662595b Compare January 12, 2025 21:22
Comment on lines 484 to 492
# test that trace is used in ppc
with pm.Model() as model_ppc:
    mu = pm.Normal("mu", 0.0, 1.0)
    a = pm.Normal("a", mu=mu, sigma=1)

ppc = pm.sample_posterior_predictive(
    trace=trace, model=model_ppc, return_inferencedata=False
)
assert "a" in ppc
Member

Can you put this in its own test?

Contributor

I would also make the test more stringent. Test that only the variables that you want are actually included in the trace. Also add a case where one node is conditionally dependent on the trace.observed_data so that you see that the auto added variables include conditional nodes.
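
A minimal sketch of what "more stringent" could mean here, using a plain dict as a hypothetical stand-in for the returned ppc (the actual return value and variable names may differ). A membership check passes even when unwanted variables are present, while comparing the full set of keys does not.

```python
# Hypothetical stand-in for the dict returned with return_inferencedata=False.
# Here "mu" plays the role of a variable that should NOT have been resampled.
ppc = {"a": [0.1, -0.3], "mu": [0.0, 0.2]}

# Loose check: passes as long as "a" is present, even with extras.
loose_ok = "a" in ppc

# Stricter check: passes only if exactly the expected variables are present.
expected = {"a"}
strict_ok = set(ppc) == expected

print(loose_ok, strict_ok)
```

With the stricter form, an accidentally resampled variable makes the test fail instead of slipping through.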

@@ -817,6 +819,8 @@ def sample_posterior_predictive(
        vars_ = [model[x] for x in var_names]
    else:
        vars_ = model.observed_RVs + observed_dependent_deterministics(model)
        if observed_data is not None:
Member

BTW, the observed_dependent_deterministics call above is not going to work if these variables are not observed in the model.

That happens with auto-imputation models, which I assume the as_model wrapper won't handle correctly either, because the models differ depending on whether you pass data or not.

Just something to keep in mind.

Contributor

I agree with this. You'll have to adapt observed_dependent_deterministics to also accept a list of extra variables that will depend on your observed_data
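
One way to picture this suggestion is the toy sketch below. It does not use PyMC: the graph is a plain dict from variable name to parent names, and `toy_observed_dependent_deterministics` and its `extra_observed` parameter are hypothetical names chosen for illustration, not the library's actual signature.

```python
from typing import Dict, List, Sequence, Set

# Toy stand-in for a model graph: each variable name maps to its direct parents.
ToyGraph = Dict[str, List[str]]


def depends_on(graph: ToyGraph, name: str, targets: Set[str]) -> bool:
    """Return True if any ancestor of `name` is in `targets`."""
    stack = list(graph.get(name, []))
    seen: Set[str] = set()
    while stack:
        parent = stack.pop()
        if parent in targets:
            return True
        if parent not in seen:
            seen.add(parent)
            stack.extend(graph.get(parent, []))
    return False


def toy_observed_dependent_deterministics(
    graph: ToyGraph,
    deterministics: Sequence[str],
    observed: Sequence[str],
    extra_observed: Sequence[str] = (),
) -> List[str]:
    """Deterministics downstream of model-observed or extra (trace-observed) names."""
    targets = set(observed) | set(extra_observed)
    return [d for d in deterministics if depends_on(graph, d, targets)]


# "y" is not observed in this toy model, but it was observed when the trace was made.
graph = {"mu": [], "y": ["mu"], "y_plus_one": ["y"], "mu_sq": ["mu"]}
dets = ["y_plus_one", "mu_sq"]

without_extra = toy_observed_dependent_deterministics(graph, dets, observed=[])
with_extra = toy_observed_dependent_deterministics(
    graph, dets, observed=[], extra_observed=["y"]
)
print(without_extra, with_extra)
```

Without the extra list, nothing is picked up because the toy model itself has no observed variables; passing the trace-observed names lets the helper find the deterministic that depends on them.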

@ricardoV94 ricardoV94 changed the title Check for observed variables in the trace as well as the model Check for observed variables in the trace Jan 14, 2025
@zaxtax
Contributor Author

zaxtax commented Jan 14, 2025 via email

Contributor

@lucianopaz lucianopaz left a comment

I second Ricardo's comments. Just a few changes and I think this will be OK to merge.

@@ -817,6 +819,8 @@ def sample_posterior_predictive(
        vars_ = [model[x] for x in var_names]
    else:
        vars_ = model.observed_RVs + observed_dependent_deterministics(model)
        if observed_data is not None:
            vars_ += [model[x] for x in observed_data if x in model]
Contributor

Maybe also add an "and x not in vars_" condition to the list comprehension.
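
A small sketch of the effect of that extra condition. It works on plain names only: `model_vars` is a hypothetical dict standing in for the model, and `vars_` holds names rather than variable objects, which is a simplification of the real code.

```python
# Toy stand-ins (assumptions): a name -> variable mapping and lists of names.
model_vars = {"mu": "<mu>", "a": "<a>", "b": "<b>"}
vars_ = ["a"]                                # already selected: observed in the model
observed_data = ["a", "b", "not_in_model"]   # names found in trace.observed_data

# Without the extra condition, "a" is added a second time.
naive = vars_ + [x for x in observed_data if x in model_vars]

# With it, only names that are in the model and not yet selected are added.
deduped = vars_ + [x for x in observed_data if x in model_vars and x not in vars_]

print(naive)
print(deduped)
```

Names that are absent from the model are skipped in both versions; the added condition only guards against selecting the same variable twice.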

@zaxtax
Contributor Author

zaxtax commented Jan 14, 2025 via email

@ricardoV94
Member

That helper was created specifically for the deterministic created by automatic imputation that joins the observed and unobserved components, so I'm not too worried about it.

But if you wanted to, you could just add a deterministic like y + 1, where y is the variable that was observed during sampling.

@zaxtax zaxtax force-pushed the pull_observes_from_idata_in_pps branch from f24a55c to 8dc0945 Compare January 19, 2025 16:47
@zaxtax zaxtax force-pushed the pull_observes_from_idata_in_pps branch from a217113 to e895a5c Compare January 19, 2025 16:52
@zaxtax
Contributor Author

zaxtax commented Jan 19, 2025

It's not fully clear what the best way is to make sure the correct conditional nodes are added when the model that produced the trace isn't readily available. So, in the meantime, I have written a test that we expect to fail.

@@ -821,6 +824,7 @@ def sample_posterior_predictive(
        vars_ = model.observed_RVs + observed_dependent_deterministics(model)
        if observed_data is not None:
            vars_ += [model[x] for x in observed_data if x in model and x not in vars_]
            vars_ += observed_dependent_deterministics(model, vars_)
Member

@ricardoV94 ricardoV94 Jan 20, 2025

This is going to duplicate deterministics of an observed variable when there's a mix of observed model variables and implied observed variables from the idata. Should the "+ observed_dependent_deterministics(...)" be called only once, after the if branch?
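
The duplication can be reproduced with a toy sketch that does not use PyMC. `toy_dependent_deterministics` and the `children` mapping are hypothetical stand-ins; each deterministic has a single direct parent to keep the example short.

```python
def toy_dependent_deterministics(children, observed_names):
    """Toy helper: deterministics whose direct parent is among the observed names."""
    return [det for det, parent in children.items() if parent in observed_names]


children = {"y_plus_one": "y", "z_plus_one": "z"}  # deterministic -> parent
model_observed = ["y"]        # observed in the model
trace_observed = ["y", "z"]   # names found in the trace's observed_data

# Pattern flagged above: helper called before the branch and again inside it.
vars_twice = model_observed + toy_dependent_deterministics(children, model_observed)
vars_twice += [x for x in trace_observed if x not in vars_twice]
vars_twice += toy_dependent_deterministics(children, vars_twice)

# Suggested pattern: gather all observed names first, then call the helper once.
observed = list(model_observed)
observed += [x for x in trace_observed if x not in observed]
vars_once = observed + toy_dependent_deterministics(children, observed)

print(vars_twice)
print(vars_once)
```

In the first pattern "y_plus_one" ends up in the list twice, because it is found once from the model-observed "y" and again on the second call; collecting the observed names first avoids that.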

@zaxtax
Contributor Author

zaxtax commented Jan 20, 2025

@lucianopaz does this address your concerns?

@@ -540,6 +540,50 @@ def test_normal_scalar_idata(self):
        ppc = pm.sample_posterior_predictive(idata, return_inferencedata=False)
        assert ppc["a"].shape == (nchains, ndraws)

    def test_external_trace(self):
Member

Remove this test? The one with det is strictly more comprehensive than this one.

@zaxtax zaxtax force-pushed the pull_observes_from_idata_in_pps branch from 4a0aa08 to 2fcf395 Compare January 20, 2025 16:02
@ricardoV94 ricardoV94 merged commit 892c37a into pymc-devs:main Jan 20, 2025
23 checks passed
3 participants