From ed6b60750d8f2b1934b4185f6255610a0dd5791f Mon Sep 17 00:00:00 2001 From: guillaumehu Date: Mon, 13 Nov 2023 13:58:56 +0100 Subject: [PATCH 1/2] docstrings OTPlanSampler --- torchcfm/optimal_transport.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/torchcfm/optimal_transport.py b/torchcfm/optimal_transport.py index 4f563ba..ebb54d1 100644 --- a/torchcfm/optimal_transport.py +++ b/torchcfm/optimal_transport.py @@ -19,6 +19,21 @@ def __init__( normalize_cost=False, **kwargs, ): + r"""Initialize the OTPlanSampler class. + + Parameters + ---------- + method : str + The method used to compute the OT plan. Can be one of "exact", "sinkhorn", + "unbalanced", or "partial". + reg : float (default : 0.05) + Entropic regularization coefficients. + reg_m : float (default : 1.0) + Marginal relaxation term for unbalanced OT (`method='unbalanced'`). + normalize_cost : bool (default : False) + Whether to normalize the cost matrix by its maximum value. + It should be set to `False` when using minibatches. + """ # ot_fn should take (a, b, M) as arguments where a, b are marginals and # M is a cost matrix if method == "exact": From b12f8fa15ae7d3b38eba569d026175805c84b495 Mon Sep 17 00:00:00 2001 From: guillaumehu Date: Mon, 13 Nov 2023 14:12:05 +0100 Subject: [PATCH 2/2] property for sigma int/float --- torchcfm/conditional_flow_matching.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/torchcfm/conditional_flow_matching.py b/torchcfm/conditional_flow_matching.py index 3e518ef..99a9742 100644 --- a/torchcfm/conditional_flow_matching.py +++ b/torchcfm/conditional_flow_matching.py @@ -54,7 +54,16 @@ def __init__(self, sigma: float = 0.0): ---------- sigma : float """ - self.sigma = sigma + self._sigma = sigma + + @property + def sigma(self): + if isinstance(self._sigma, float): + return self._sigma + elif isinstance(self._sigma, int): + return float(self._sigma) + else: + raise ValueError("Sigma must be a float or int.") def compute_mu_t(self, x0, x1, t): """