Maybe add epsilon to RelaxedOneHotCategorical to prevent underflow #3304
FYI: I also took a stab at fixing the straight-through categorical. It could still use some work, but it works for me in a GMM-VAE, where the previous RelaxedCategoricalStraightThrough would not train:

```python
import torch
from torch.distributions import Categorical, Distribution, TransformedDistribution, constraints
from torch.distributions.transforms import ExpTransform
from torch.distributions.utils import broadcast_all, clamp_probs
from pyro.distributions.torch_distribution import TorchDistributionMixin


class RelaxedQuantizeCategorical(torch.autograd.Function):
    # Class-level settings, shared by every distribution that uses this Function.
    temperature = None  # Default temperature
    epsilon = 1e-10  # Default epsilon

    @staticmethod
    def set_temperature(new_temperature):
        RelaxedQuantizeCategorical.temperature = new_temperature

    @staticmethod
    def set_epsilon(new_epsilon):
        RelaxedQuantizeCategorical.epsilon = new_epsilon

    @staticmethod
    def forward(ctx, soft_value):
        temperature = float(RelaxedQuantizeCategorical.temperature)
        epsilon = RelaxedQuantizeCategorical.epsilon
        # Gumbel-softmax sample in log space, as in ExpRelaxedCategorical.rsample.
        uniforms = clamp_probs(
            torch.rand(soft_value.shape, dtype=soft_value.dtype, device=soft_value.device)
        )
        gumbels = -((-(uniforms.log())).log())
        scores = (soft_value + gumbels) / temperature
        outs = scores - scores.logsumexp(dim=-1, keepdim=True)
        outs = outs.exp()
        outs = outs + epsilon  # Use the class variable epsilon to avoid exact zeros
        # Renormalize over the category (last) dimension and go back to log space.
        hard_value = (outs / outs.sum(-1, keepdim=True)).log()
        hard_value._unquantize = soft_value
        return hard_value

    @staticmethod
    def backward(ctx, grad):
        # Straight-through: pass the incoming gradient back unchanged.
        return grad


class ExpRelaxedCategoricalStraightThrough(Distribution):
    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
    support = (
        constraints.real_vector
    )  # The true support is actually a submanifold of this.
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None, epsilon=1e-10):
        self._categorical = Categorical(probs, logits)
        self.temperature = temperature
        RelaxedQuantizeCategorical.set_temperature(temperature)
        RelaxedQuantizeCategorical.set_epsilon(epsilon)
        batch_shape = self._categorical.batch_shape
        event_shape = self._categorical.param_shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(ExpRelaxedCategoricalStraightThrough, _instance)
        batch_shape = torch.Size(batch_shape)
        new.temperature = self.temperature
        new._categorical = self._categorical.expand(batch_shape)
        super(ExpRelaxedCategoricalStraightThrough, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        return self._categorical._new(*args, **kwargs)

    @property
    def param_shape(self):
        return self._categorical.param_shape

    @property
    def logits(self):
        return self._categorical.logits

    @property
    def probs(self):
        return self._categorical.probs

    def rsample(self, sample_shape=torch.Size()):
        # Note: sample_shape is ignored; one sample is drawn per batch element.
        outs = RelaxedQuantizeCategorical.apply(self.logits)
        return outs

    def log_prob(self, value):
        value = getattr(value, "_unquantize", value)
        if self._validate_args:
            self._validate_sample(value)
        logits, value = broadcast_all(self.logits, value)
        # Note: unlike ExpRelaxedCategorical.log_prob, this score does not depend on `value`.
        score = logits
        score = (score - score.logsumexp(dim=-1, keepdim=True)).sum(-1)
        return score


class SafeAndRelaxedOneHotCategoricalStraightThrough(TransformedDistribution, TorchDistributionMixin):
    # Don't understand why these were broken (doesn't call the straight-through rsample in Pyro?)
    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
    support = constraints.simplex
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None, epsilon=1e-10):
        base_dist = ExpRelaxedCategoricalStraightThrough(
            temperature, probs, logits, validate_args=validate_args, epsilon=epsilon
        )
        super().__init__(base_dist, ExpTransform(), validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(SafeAndRelaxedOneHotCategoricalStraightThrough, _instance)
        return super().expand(batch_shape, _instance=new)

    @property
    def temperature(self):
        return self.base_dist.temperature

    @property
    def logits(self):
        return self.base_dist.logits

    @property
    def probs(self):
        return self.base_dist.probs
```
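`RelaxedQuantizeCategorical` above reads its temperature and epsilon from class attributes, so two distribution instances built with different temperatures overwrite each other's settings (the most recently constructed one wins). Below is a minimal sketch of the same straight-through idea with those settings passed as arguments to `apply` instead. The name `RelaxedStraightThrough` is hypothetical, and this is an illustration under that assumption, not a drop-in replacement for the classes above:

```python
import torch
from torch.distributions.utils import clamp_probs


class RelaxedStraightThrough(torch.autograd.Function):
    """Hypothetical variant: Gumbel-softmax log-sample forward, identity backward."""

    @staticmethod
    def forward(ctx, logits, temperature, epsilon):
        # Gumbel noise from clamped uniforms, as ExpRelaxedCategorical does.
        uniforms = clamp_probs(torch.rand_like(logits))
        gumbels = -((-(uniforms.log())).log())
        scores = (logits + gumbels) / temperature
        # Softmax over the last dimension, then floor every entry by adding epsilon.
        probs = (scores - scores.logsumexp(dim=-1, keepdim=True)).exp() + epsilon
        # Renormalize and return log-probabilities (finite because probs >= epsilon).
        return (probs / probs.sum(-1, keepdim=True)).log()

    @staticmethod
    def backward(ctx, grad):
        # Straight-through for the logits; no gradient for temperature or epsilon.
        return grad, None, None


logits = torch.randn(4, 10, requires_grad=True)
sample = RelaxedStraightThrough.apply(logits, 0.1, 1e-10)
weights = torch.randn(4, 10)
(sample * weights).sum().backward()
print(torch.isfinite(sample).all())
print(torch.equal(logits.grad, weights))
```

`backward` must return one value per input to `forward`. Returning `None` for the two non-tensor inputs tells autograd they get no gradient. Passing `grad` through unchanged for the logits is the straight-through part: in the backward pass, the sampling step is treated as the identity.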
Hi @mtvector, I think our general design principle with distributions is to make them hackable with decent defaults. In this case I'd lean towards letting users add their own epsilon in a custom distribution class. In my own projects I often have one or two custom distributions for each data science project. What do you think of a simple patched distribution, just for your project?

```python
import torch
from torch.distributions.utils import clamp_probs
from pyro.distributions import ExpRelaxedCategorical


class SafeExpRelaxedCategorical(ExpRelaxedCategorical):
    epsilon = 1e-10

    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        uniforms = clamp_probs(
            torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
        )
        gumbels = -((-(uniforms.log())).log())
        scores = (self.logits + gumbels) / self.temperature
        # could also clamp_probs
        outs = scores - scores.logsumexp(dim=-1, keepdim=True)
        outs = outs.exp()
        outs = outs + self.epsilon  # prevent underflow
        # Normalize over the category (last) dimension, whatever the sample/batch shape.
        outs = (outs / outs.sum(-1, keepdim=True)).log()
        return outs
```

Actually I often find that (1) clamping is safer than adding, and (2) it's best to use:

```python
class SafeExpRelaxedCategorical2(ExpRelaxedCategorical):
    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        uniforms = clamp_probs(
            torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
        )
        gumbels = -((-(uniforms.log())).log())
        scores = (self.logits + gumbels) / self.temperature
        # could also clamp_probs
        outs = scores - scores.logsumexp(dim=-1, keepdim=True)
        outs = outs.exp()
        # Floor at the smallest positive normal number of the dtype.
        outs = outs.clamp(min=torch.finfo(outs.dtype).tiny)
        outs = (outs / outs.sum(-1, keepdim=True)).log()
        return outs
```

WDYT?
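To illustrate point (1), here is a small comparison of the two floors applied to a probability vector whose last entry has already underflowed to zero. The helper names `add_epsilon` and `clamp_tiny` are hypothetical. The float16 line shows one possible reason to prefer a dtype-aware floor; it is offered as an illustration, not as the reasoning given above. A fixed `1e-10` is not representable in float16, while `torch.finfo(dtype).tiny` always is:

```python
import torch


def add_epsilon(probs, epsilon=1e-10):
    # Shift every entry by a fixed epsilon, then renormalize over the last dim.
    outs = probs + epsilon
    return outs / outs.sum(-1, keepdim=True)


def clamp_tiny(probs):
    # Raise only the entries below the dtype's smallest normal number.
    outs = probs.clamp(min=torch.finfo(probs.dtype).tiny)
    return outs / outs.sum(-1, keepdim=True)


probs = torch.tensor([0.75, 0.25, 0.0])  # last entry has underflowed to zero
print(add_epsilon(probs).log())
print(clamp_tiny(probs).log())

# A fixed epsilon is not dtype-aware: 1e-10 rounds to zero in float16.
print(torch.tensor(1e-10, dtype=torch.float16).item())  # 0.0
print(torch.finfo(torch.float16).tiny)  # smallest positive normal float16
```

Both variants give finite logs in float32. In this example, clamping leaves the two entries that did not underflow exactly unchanged and only lifts the zero. Adding epsilon shifts every entry, although for entries near 1 the shift is absorbed by float32 rounding.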
Hi @fritzo,

This gives the error due to underflow:

You're right about the fix. For instance, your first version resolves the underflow more elegantly than what I proposed:

This gives no error, just like my SafeAndRelaxedOneHotCategoricalStraightThrough above. So it seems like the default RelaxedOneHotCategorical should use one of the SafeExpRelaxedCategorical bases you've proposed here?
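Here is a minimal sketch of what that could look like: a relaxed one-hot distribution built the same way as PyTorch's `RelaxedOneHotCategorical` (an `ExpTransform` applied to an exp-relaxed base), but using the clamping base proposed above. It subclasses PyTorch's `ExpRelaxedCategorical` directly so it runs without Pyro. `SafeRelaxedOneHotCategorical` is a hypothetical name, and a Pyro version would presumably also mix in `TorchDistributionMixin`:

```python
import torch
from torch.distributions import TransformedDistribution, constraints
from torch.distributions.relaxed_categorical import ExpRelaxedCategorical
from torch.distributions.transforms import ExpTransform
from torch.distributions.utils import clamp_probs


class SafeExpRelaxedCategorical(ExpRelaxedCategorical):
    """Exp-relaxed categorical whose probabilities are floored at finfo(dtype).tiny."""

    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        uniforms = clamp_probs(
            torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
        )
        gumbels = -((-(uniforms.log())).log())
        scores = (self.logits + gumbels) / self.temperature
        outs = (scores - scores.logsumexp(dim=-1, keepdim=True)).exp()
        # Lift only the entries that underflowed, then renormalize over categories.
        outs = outs.clamp(min=torch.finfo(outs.dtype).tiny)
        return (outs / outs.sum(-1, keepdim=True)).log()


class SafeRelaxedOneHotCategorical(TransformedDistribution):
    """Hypothetical RelaxedOneHotCategorical variant built on the safe base."""

    support = constraints.simplex
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        base_dist = SafeExpRelaxedCategorical(
            temperature, probs, logits, validate_args=validate_args
        )
        super().__init__(base_dist, ExpTransform(), validate_args=validate_args)


torch.manual_seed(0)
# Many categories and a low temperature: the regime reported to underflow.
logits = torch.linspace(0.0, 50.0, 1000)
dist = SafeRelaxedOneHotCategorical(torch.tensor(0.1), logits=logits)
x = dist.rsample(torch.Size([8]))
print(x.shape)  # torch.Size([8, 1000])
print(bool((x > 0).all()))  # True
```

Keeping the fix in the base distribution means the one-hot wrapper stays a plain `ExpTransform`, which mirrors how the stock distribution is composed.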
I've noticed that pyro.distributions.RelaxedOneHotCategorical tends to underflow pretty dramatically if you decrease the temperature below about 0.3 with many categories. I've been adding a slight modification to the rsample method of the ExpRelaxedCategorical class it's built on. I just wanted to post this in case you want to consider this (maybe hacky) fix, which makes the distribution work with Pyro's support constraints.
Modified from https://github.com/pytorch/pytorch/blob/main/torch/distributions/relaxed_categorical.py:
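One quick way to see the underflow is with PyTorch's own `RelaxedOneHotCategorical`, which the Pyro class wraps. Draw a sample with many categories at a high and a low temperature, and count the entries that are exactly zero. The particular logits and temperatures below are illustrative choices, not values from the original report:

```python
import torch
from torch.distributions import RelaxedOneHotCategorical

torch.manual_seed(0)
num_categories = 1000
# Widely spread logits make the gap between the largest and smallest scores large.
logits = torch.linspace(0.0, 50.0, num_categories)

zero_counts = {}
for temperature in (1.0, 0.1):
    dist = RelaxedOneHotCategorical(torch.tensor(temperature), logits=logits)
    x = dist.rsample()
    # Dividing by a small temperature stretches the log-probabilities until exp() hits 0.
    zero_counts[temperature] = int((x == 0).sum())
    print(f"temperature={temperature}: {zero_counts[temperature]} of {num_categories} entries are exactly 0")
```

Exact zeros lie on the boundary of the simplex. Taking their log, for example inside `log_prob`, which inverts the `ExpTransform`, gives `-inf`, and that can then propagate as NaN.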