Skip to content

Commit

Permalink
Fix bug in computing log likelihood of flow model.
Browse files Browse the repository at this point in the history
  • Loading branch information
zjowowen committed Jul 16, 2024
1 parent 46fa20e commit d1791bf
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions grl/generative_models/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,14 @@
from tensordict import TensorDict
from torch.distributions import Independent, Normal

from grl.generative_models.diffusion_model import (
DiffusionModel,
EnergyConditionalDiffusionModel,
)
from grl.numerical_methods.numerical_solvers.ode_solver import (
ODESolver,
)
from grl.utils import find_parameters


def compute_likelihood(
model: Union[DiffusionModel, EnergyConditionalDiffusionModel],
model,
x: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor],
t: torch.Tensor = None,
condition: Union[torch.Tensor, TensorDict] = None,
Expand Down Expand Up @@ -49,7 +45,7 @@ def compute_likelihood(
"IndependentConditionalFlowModel",
"OptimalTransportConditionalFlowModel",
]:
model_drift = lambda t, x: model.model(t, x, condition)
model_drift = lambda t, x: - model.model(1 - t, x, condition)
model_params = find_parameters(model.model)
elif model.get_type() == "FlowModel":
model_drift = lambda t, x: model.model(t, x, condition)
Expand Down

0 comments on commit d1791bf

Please sign in to comment.