Improve user experience in inference with partially observed discrete variables #3255
@fritzo It seems that you wrote the following comment in (...). I ran the tests in (...). If you point me in the right direction I can write additional tests and work on a pull request for this issue. Thanks!
Hi @gui11aume, responding to your last question: the best tests we have of this behavior are (...). For example, this incorrect change to `pyro/distributions/score_parts.py`

```diff
diff --git a/pyro/distributions/score_parts.py b/pyro/distributions/score_parts.py
index 15d39156..bb758d82 100644
--- a/pyro/distributions/score_parts.py
+++ b/pyro/distributions/score_parts.py
@@ -25,6 +25,6 @@ class ScoreParts(
         :type mask: torch.BoolTensor or None
         """
         log_prob = scale_and_mask(self.log_prob, scale, mask)
-        score_function = self.score_function  # not scaled
+        score_function = scale_and_mask(self.score_function, scale, mask)
         entropy_term = scale_and_mask(self.entropy_term, scale, mask)
         return ScoreParts(log_prob, score_function, entropy_term)
```

results in a test failure in around one second: (...)
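For context, the `scale_and_mask` helper used in `score_parts.py` multiplies a tensor of log-probability terms by a scale and zeroes out entries where the mask is False. A simplified stand-in that only illustrates that behavior (this is not Pyro's actual implementation):

```python
import torch

def scale_and_mask_sketch(tensor, scale=1.0, mask=None):
    """Simplified stand-in for Pyro's scale_and_mask helper (illustration only):
    scale the tensor and zero out entries where the boolean mask is False."""
    if mask is None:
        return tensor * scale
    return torch.where(mask, tensor * scale, tensor.new_zeros(()))

t = torch.tensor([1.0, 2.0, 3.0])
print(scale_and_mask_sketch(t, scale=0.5, mask=torch.tensor([True, False, True])))
# tensor([0.5000, 0.0000, 1.5000])
```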
Thanks @fritzo! That's so useful for understanding how the code works. I'll start from there and think of ways to address the issues without breaking everything. By the way, I am reading your work of the past few years and I am so impressed. I don't say this very often, but you are an example to look up to.
I think I understand why the `score_function` term is not scaled. Loosely speaking, if we estimate the gradient of the ELBO with only half of the terms, then we have to multiply the estimate by two to keep it unbiased, and that scale factor already enters the REINFORCE surrogate through the other (scaled) factor, so applying it to `score_function` as well would count it twice. As far as I understand, scaling and masking are processed together in `scale_and_mask`, which is why `score_function` currently gets neither. The change could instead keep the mask but not the scale:

```diff
-        score_function = self.score_function  # not scaled
+        score_function = scale_and_mask(self.score_function, 1.0, mask)  # not scaled
```

I believe that this went unnoticed because there are no tests to check the gradient when masking is active. I'll see if I can write some for this case, along the lines of those in (...).
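To make the "scale would be applied twice" argument concrete, here is a small toy check. This is not Pyro's code; `log_r` merely stands in for the scaled, detached downstream cost that multiplies `score_function` in the surrogate loss:

```python
import torch

theta = torch.tensor(0.5, requires_grad=True)
scale = 0.25                                       # e.g. a subsampling scale from poutine.scale
q = torch.distributions.Normal(theta, 1.0)
z = q.sample()                                     # non-reparameterized sample
score_function = q.log_prob(z)                     # differentiable log q(z), as in ScoreParts
log_r = (scale * z ** 2).detach()                  # stand-in downstream cost: the scale already enters here

surrogate_ok = score_function * log_r              # scale counted once (current behavior)
surrogate_bad = (scale * score_function) * log_r   # scaling score_function too: scale counted twice

g_ok, = torch.autograd.grad(surrogate_ok, theta, retain_graph=True)
g_bad, = torch.autograd.grad(surrogate_bad, theta)
print(g_bad / g_ok)                                # exactly `scale`: the gradient is shrunk a second time
```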
@gui11aume yes, your explanation sounds right. I'm sorry I don't recall why masking behaves differently from scaling; I vaguely recall there was a reason, but I forget whether that reason was due to deep mathematics or merely incidental complexity in our implementation. Additional tests would be great!
Thanks for confirming @fritzo. Below is a very long post, I don't expect anyone to read it. It is mostly here for reference, to keep track of my rationale.

Issue 1: Code failure when masking

After some thinking, my opinion is that it should be allowed to have sites in the model but not in the guide (the examples below illustrate this). Consider first a simple Gaussian model with a fully observed site:

```python
import pyro
import pyro.distributions as dist
import torch
def model(data):
    z = pyro.sample("z", dist.Normal(0, 1))
    with pyro.plate("data", len(data)):
        pyro.sample("x", dist.Normal(z, 1), obs=data)
    return

def guide(data):
    loc = pyro.param("loc", lambda: torch.tensor([0.]))
    scale = pyro.param("scale", lambda: torch.tensor([1.]))
    z_dist = dist.Normal(loc, scale)
    pyro.sample("z", z_dist)

svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({"lr": 0.01}), pyro.infer.Trace_ELBO())
for step in range(2000):
    svi.step(torch.tensor([2.]))
print(pyro.param("loc"), pyro.param("scale"))
# tensor([0.9798], requires_grad=True) tensor([0.7382], requires_grad=True)
```

The inference is correct: in this case the posterior distribution is Normal(1, sqrt(0.5)), i.e., a mean of 1 and a standard deviation of about 0.707, close to the estimates above.
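For reference, a quick analytic check of that posterior using the conjugate-normal formulas (standard Normal prior, unit observation noise, a single observation x = 2):

```python
# Conjugate-normal check: prior N(0, 1), likelihood N(z, 1), one observation x = 2.
x = 2.0
post_var = 1.0 / (1.0 + 1.0)       # 1 / (prior precision + observation precision)
post_mean = post_var * (0.0 + x)   # precision-weighted mean
print(post_mean, post_var ** 0.5)  # 1.0 0.7071... -- close to the SVI estimates above
```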
Now let us run the same example with an observation mask on `x`:

```python
import pyro
import pyro.distributions as dist
import torch
def model(data, mask):
    z = pyro.sample("z", dist.Normal(0, 1))
    with pyro.plate("data", len(data)):
        pyro.sample("x", dist.Normal(z, 1), obs=data, obs_mask=mask)
    return

def guide(data, mask):
    loc = pyro.param("loc", lambda: torch.tensor([0.]))
    scale = pyro.param("scale", lambda: torch.tensor([1.]))
    z_dist = dist.Normal(loc, scale)
    pyro.sample("z", z_dist)

svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({"lr": 0.01}), pyro.infer.Trace_ELBO())
for step in range(2000):
    svi.step(torch.tensor([2., 0.]), torch.tensor([True, False]))
print(pyro.param("loc"), pyro.param("scale"))
# (...)/pyro/util.py:303: UserWarning: Found vars in model but not guide: {'x_unobserved'}
# tensor([0.9631], requires_grad=True) tensor([0.7807], requires_grad=True)
```

We get a warning but the inference is correct. In the first model, we have only one observation (the value 2); here the second value is masked out, so we are effectively conditioning on the same data, and indeed the estimates are very close. Moving on, if a masked variable has no `x_unobserved` site in the guide and the guide distribution has no `rsample` (so the reparametrization trick cannot be used), the code fails:

```python
import pyro
import pyro.distributions as dist
import torch
def model(data, mask):
    z = pyro.sample("z", dist.Normal(0, 1))
    with pyro.plate("data", len(data)):
        pyro.sample("x", dist.Normal(z, 1), obs=data, obs_mask=mask)
    return

def guide(data, mask):
    loc = pyro.param("loc", lambda: torch.tensor([0.]))
    scale = pyro.param("scale", lambda: torch.tensor([1.]))
    z_dist = dist.Normal(loc, scale)
    z_dist.has_rsample = False  # <== Pretend we cannot use the reparametrization trick.
    pyro.sample("z", z_dist)

svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({"lr": 0.01}), pyro.infer.Trace_ELBO())
for step in range(2000):
    svi.step(torch.tensor([2., 0.]), torch.tensor([True, False]))
print(pyro.param("loc"), pyro.param("scale"))
# (...)/pyro/util.py:303: UserWarning: Found vars in model but not guide: {'x_unobserved'}
# warnings.warn(f"Found vars in model but not guide: {bad_sites}")
# Traceback (most recent call last):
#   File "tmp3.py", line 20, in <module>
#     svi.step(torch.tensor([2., 0.]), torch.tensor([True, False]))
#   File "(...)/pyro/infer/svi.py", line 145, in step
#     loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
#   File "(...)/pyro/infer/trace_elbo.py", line 141, in loss_and_grads
#     loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(
#   File "(...)/pyro/infer/trace_elbo.py", line 106, in _differentiable_loss_particle
#     log_r = _compute_log_r(model_trace, guide_trace)
#   File "(...)/pyro/infer/trace_elbo.py", line 27, in _compute_log_r
#     log_r_term = log_r_term - guide_trace.nodes[name]["log_prob"]
# KeyError: 'x_unobserved'
```

The code should have the same behavior as before: masking has nothing to do with whether the reparametrization trick is available. The failure comes from places in the ELBO implementations that assume every unobserved site of the model also exists in the guide:

```python
# https://github.com/pyro-ppl/pyro/blob/0e82cad30f75b892a07e6c9a5f9e24f2cb5d0d81/pyro/infer/trace_elbo.py#L20C1-L29C17
def _compute_log_r(model_trace, guide_trace):
    log_r = MultiFrameTensor()
    stacks = get_plate_stacks(model_trace)
    for name, model_site in model_trace.nodes.items():
        if model_site["type"] == "sample":
            log_r_term = model_site["log_prob"]
            if not model_site["is_observed"]:
                log_r_term = log_r_term - guide_trace.nodes[name]["log_prob"]  # <== This can fail.
            log_r.add((stacks[name], log_r_term.detach()))
    return log_r
```

```python
# https://github.com/pyro-ppl/pyro/blob/0e82cad30f75b892a07e6c9a5f9e24f2cb5d0d81/pyro/infer/trace_mean_field_elbo.py#L107C1-L114C63
for name, model_site in model_trace.nodes.items():
    if model_site["type"] == "sample":
        if model_site["is_observed"]:
            elbo_particle = elbo_particle + model_site["log_prob_sum"]
        else:
            guide_site = guide_trace.nodes[name]  # <== This can fail.
            if is_validation_enabled():
                check_fully_reparametrized(guide_site)
```

```python
# https://github.com/pyro-ppl/pyro/blob/0e82cad30f75b892a07e6c9a5f9e24f2cb5d0d81/pyro/infer/tracegraph_elbo.py#L217C1-L223C66
# construct all the reinforce-like terms.
# we include only downstream costs to reduce variance
# optionally include baselines to further reduce variance
for node, downstream_cost in downstream_costs.items():
    guide_site = guide_trace.nodes[node]  # <== This can fail.
    downstream_cost = downstream_cost.sum_to(guide_site["cond_indep_stack"])
    score_function = guide_site["score_parts"].score_function
```
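The issue description further down suggests replacing these lookups with a fail-safe version whose error message points users in the right direction. A hypothetical helper along those lines (an illustration only, not the change actually made in Pyro or in the pull request mentioned below) could look like:

```python
def get_guide_site(guide_trace, name):
    """Hypothetical fail-safe guide-site lookup (illustration only)."""
    try:
        return guide_trace.nodes[name]
    except KeyError:
        raise KeyError(
            f"Site '{name}' exists in the model but not in the guide. If it was created "
            f"automatically by obs_mask, add a sample site named '{name}' to the guide "
            "(for example inside poutine.mask)."
        ) from None
```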
Issue 2: Counter-intuitive gradient

Let us go back to the second case, where the reparametrization trick is available, and let us try to infer the distribution of the missing value of `x` by adding an `x_unobserved` site to the guide:

```python
import pyro
import pyro.distributions as dist
import torch
def model(data, mask):
    z = pyro.sample("z", dist.Normal(0, 1))
    with pyro.plate("data", len(data)):
        pyro.sample("x", dist.Normal(z, 1), obs=data, obs_mask=mask)
    return

def guide(data, mask):
    loc = pyro.param("loc", lambda: torch.tensor([0.]))
    scale = pyro.param("scale", lambda: torch.tensor([1.]))
    z_dist = dist.Normal(loc, scale)
    pyro.sample("z", z_dist)
    with pyro.plate("data", len(data)):
        loc_x = pyro.param("loc_x", lambda: torch.tensor([0., 0.]))
        scale_x = pyro.param("scale_x", lambda: torch.tensor([1., 1.]))
        with pyro.poutine.mask(mask=~mask):
            pyro.sample("x_unobserved", dist.Normal(loc_x, scale_x))

svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({"lr": 0.01}), pyro.infer.Trace_ELBO())
for step in range(2000):
    svi.step(torch.tensor([2., 0.]), torch.tensor([True, False]))
print(pyro.param("loc_x"), pyro.param("scale_x"))
# tensor([0.0000, 0.9787], requires_grad=True) tensor([1.0000, 1.0632], requires_grad=True)
```

As expected, the warning is gone. We also see that the parameters for the observed values are exactly as they were initialized, meaning that for these values the gradient was 0 throughout, as expected. It is therefore the behavior that should also be achieved when the variables have no `rsample`. This can be done with the update mentioned above, where the `score_function` terms are masked but not scaled:

```diff
-        score_function = self.score_function  # not scaled
+        score_function = scale_and_mask(self.score_function, 1.0, mask)  # not scaled
```

I have written some new tests and I will open a pull request draft shortly.
I have opened pull request #3265 with the changes discussed above and some new tests involving gradients with masking.
Summary
Inference for partially observed discrete variables occasionally produces counter-intuitive results. These are not bugs, but users may waste a lot of time dealing with them or trying to understand them. The behavior has been tested on Pyro 1.8.5 and 1.8.6.
A simple example with coins
The example below is meant to show the kind of context in which the issues appear. It is artificial and has no practical application, but it is inspired by real examples I stumbled upon. In the model, we flip a fair coin and do not show the result; if it lands 'heads' we flip a coin with bias 0.05; if it lands 'tails' we flip a coin with bias 0.95. We always observe the result of the biased coin (but not which coin was flipped). In the guide, we simply sample the unbiased coin.
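The original code blocks are not reproduced on this page; a minimal sketch of the kind of model and guide described above could look like this (a reconstruction, with the parameter name `post_p` taken from the discussion below; 'tails' is encoded as 1 and 'heads' as 0):

```python
import pyro
import pyro.distributions as dist
import torch

def model(obs):
    unbiased = pyro.sample("unbiased", dist.Bernoulli(0.5))   # fair coin, result hidden
    p = 0.05 + 0.9 * unbiased                                 # heads (0) -> bias 0.05, tails (1) -> bias 0.95
    pyro.sample("obs", dist.Bernoulli(p), obs=obs)            # biased coin, always observed

def guide(obs):
    post_p = pyro.param("post_p", lambda: torch.tensor(0.5),
                        constraint=dist.constraints.unit_interval)
    pyro.sample("unbiased", dist.Bernoulli(post_p))

svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({"lr": 0.01}), pyro.infer.Trace_ELBO())
for step in range(2000):
    svi.step(torch.tensor(1.))    # the biased coin landed 'tails'
print(pyro.param("post_p"))       # should move toward ~0.95
```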
The result is correct: the second coin landed 'tails', so the posterior probability that the unbiased coin landed 'tails' is 0.95.
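A quick Bayes-rule check of that number:

```python
# P(unbiased = tails | biased coin shows tails)
prior_tails = 0.5
p_obs_given_tails, p_obs_given_heads = 0.95, 0.05
posterior = prior_tails * p_obs_given_tails / (
    prior_tails * p_obs_given_tails + (1 - prior_tails) * p_obs_given_heads)
print(posterior)  # 0.95
```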
Issue 1: Code failure when masking
If the second coin is sometimes observed, we can introduce an observation mask for the `obs` sample. Let us modify the code and run the same example, i.e., we specify that the second coin landed 'tails' and that this is observed. The code fails with the error below: first there is a warning for a missing site in the guide, and then a `KeyError` for the same reason. This is counter-intuitive: either guides with missing sites should be allowed (warn only), or they should not (raise an error for every guide with missing sites).

The error comes from a part of the code that evaluates the loss using the REINFORCE estimator (i.e., when the reparametrization trick cannot be used, as in the case of discrete random variables). Line 27 in `trace_elbo.py` assumes that every unobserved site in the model also exists in the guide. The user may not be aware that the site `obs_unobserved` is created in the model (but not in the guide) as soon as the argument `obs_mask` is not `None`.

The solution is to define the sample `obs_unobserved` in the guide (see how below), but there are barely any mentions of this, so we cannot assume that users will do it. If guides with missing sites are allowed, line 27 in `trace_elbo.py` should be replaced with a fail-safe version. Ideally, a message could point users in the right direction if Pyro creates an `_unobserved` site that is not in the guide. A sketch of a setup that triggers this failure is shown below.
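A sketch of the kind of setup that triggers the failure described above (a reconstruction, not the original issue's code): the biased coin gets an `obs_mask`, but the guide does not define the auto-created `obs_unobserved` site.

```python
import pyro
import pyro.distributions as dist
import torch

def model(obs, mask):
    unbiased = pyro.sample("unbiased", dist.Bernoulli(0.5))
    p = 0.05 + 0.9 * unbiased                       # heads (0) -> bias 0.05, tails (1) -> bias 0.95
    # obs_mask creates an "obs_unobserved" site in the model, even if mask is all True.
    pyro.sample("obs", dist.Bernoulli(p), obs=obs, obs_mask=mask)

def guide(obs, mask):
    post_p = pyro.param("post_p", lambda: torch.tensor(0.5),
                        constraint=dist.constraints.unit_interval)
    pyro.sample("unbiased", dist.Bernoulli(post_p))  # no "obs_unobserved" site here

svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({"lr": 0.01}), pyro.infer.Trace_ELBO())
# Warns about vars in the model but not the guide, then raises KeyError: 'obs_unobserved'.
svi.step(torch.tensor(1.), torch.tensor(True))
```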
Issue 2: Counter-intuitive gradient

Now if the unbiased coin is sometimes observed, we can introduce an observation mask for the `unbiased` sample, together with the observations when they are available. As mentioned above, we need to add a site in the guide called `unbiased_unobserved` explaining what to do when the coin is not observed (i.e., sample it as we were doing until now). We have to sample the whole tensor; Pyro will automatically mix observed and sampled values for us as needed.

Some values in `unbiased_unobserved` are sampled for nothing: they will be replaced with the observed values when these are available. In this case, the sampled values have no effect on the inference, but just to be sure, we are going to mask them in the guide to set their `log_prob` terms to 0. We do this by using `poutine.mask`, where we invert the observation mask with `~mask`. A sketch of such a guide is shown below.
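A sketch of the kind of model and guide described above, with an observation mask on the unbiased coin (again a reconstruction, not the original issue's code; the observed values are illustrative):

```python
import pyro
import pyro.distributions as dist
import torch

def model(obs, unbiased_obs, mask):
    with pyro.plate("flips", len(obs)):
        unbiased = pyro.sample("unbiased", dist.Bernoulli(0.5),
                               obs=unbiased_obs, obs_mask=mask)
        p = 0.05 + 0.9 * unbiased            # heads (0) -> bias 0.05, tails (1) -> bias 0.95
        pyro.sample("obs", dist.Bernoulli(p), obs=obs)

def guide(obs, unbiased_obs, mask):
    with pyro.plate("flips", len(obs)):
        post_p = pyro.param("post_p", lambda: torch.full((len(obs),), 0.5),
                            constraint=dist.constraints.unit_interval)
        with pyro.poutine.mask(mask=~mask):
            # Every entry is sampled; entries where mask is True are replaced in the model
            # by the observed values, and the mask zeroes out their log_prob in the guide.
            pyro.sample("unbiased_unobserved", dist.Bernoulli(post_p))

svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({"lr": 0.01}), pyro.infer.Trace_ELBO())
obs = torch.tensor([1., 1., 1.])              # biased coin: tails, tails, tails
unbiased_obs = torch.tensor([1., 0., 0.])     # first two flips of the unbiased coin observed
mask = torch.tensor([True, True, False])      # third flip unobserved (its value 0 is a placeholder)
for step in range(2000):
    svi.step(obs, unbiased_obs, mask)
# The point of Issue 2: the post_p entries for the two observed flips may drift away from 0.5
# even though they have no effect on the model.
print(pyro.param("post_p"))
```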
In this example, we observed the first two flips of the unbiased coin, but not the third. We set the unobserved value to 'heads' with `0`, but this is irrelevant because the value is never used throughout the inference. The inference is correct for the third flip, and there is nothing to infer for the first two flips because the values were observed... So why did the values of `post_p` change from the initial `0.5`, and what do the current values represent?

As far as I understand, the values have no special meaning. There is nothing to infer anyway. So why did they change? Once again, this has to do with the way Pyro evaluates the loss using the REINFORCE estimator. Internally, it keeps track of a `log_prob` term and a `score_function` term for the sites of the guide. The `log_prob` terms are masked but not the `score_function` terms, so all the values of `unbiased_unobserved` contribute to the gradient, even those that are overwritten by observed values.

I don't think that this has side effects, so this is not really a bug. The issue here is that Pyro is difficult enough to debug, and erratic behaviors make it harder. It would help if parameters that have no effect on the inference had gradient 0, so that the user gets alerted when there is an error in the model (e.g., when values that should have no effect on the inference do in fact have an effect).