From d1791bfccc9270853b4523160a5c9c28e68e0aa6 Mon Sep 17 00:00:00 2001 From: zjowowen Date: Tue, 16 Jul 2024 21:05:19 +0800 Subject: [PATCH] Fix bug in computing log likelihood of flow model. --- grl/generative_models/metric.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/grl/generative_models/metric.py b/grl/generative_models/metric.py index f54051b..46cd25f 100644 --- a/grl/generative_models/metric.py +++ b/grl/generative_models/metric.py @@ -5,10 +5,6 @@ from tensordict import TensorDict from torch.distributions import Independent, Normal -from grl.generative_models.diffusion_model import ( - DiffusionModel, - EnergyConditionalDiffusionModel, -) from grl.numerical_methods.numerical_solvers.ode_solver import ( ODESolver, ) @@ -16,7 +12,7 @@ def compute_likelihood( - model: Union[DiffusionModel, EnergyConditionalDiffusionModel], + model, x: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor], t: torch.Tensor = None, condition: Union[torch.Tensor, TensorDict] = None, @@ -49,7 +45,7 @@ def compute_likelihood( "IndependentConditionalFlowModel", "OptimalTransportConditionalFlowModel", ]: - model_drift = lambda t, x: model.model(t, x, condition) + model_drift = lambda t, x: - model.model(1 - t, x, condition) model_params = find_parameters(model.model) elif model.get_type() == "FlowModel": model_drift = lambda t, x: model.model(t, x, condition)