Skip to content

Commit

Permalink
Merge pull request #1 from guillaumehu/doc_otplan
Browse files Browse the repository at this point in the history
doc & sigma int
  • Loading branch information
guillaumehu authored Nov 13, 2023
2 parents 21cd0c8 + b12f8fa commit 25c519e
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
11 changes: 10 additions & 1 deletion torchcfm/conditional_flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,16 @@ def __init__(self, sigma: float = 0.0):
----------
sigma : float
"""
self.sigma = sigma
self._sigma = sigma

@property
def sigma(self):
if isinstance(self._sigma, float):
return self._sigma
elif isinstance(self._sigma, int):
return float(self._sigma)
else:
raise ValueError("Sigma must be a float or int.")

def compute_mu_t(self, x0, x1, t):
"""
Expand Down
15 changes: 15 additions & 0 deletions torchcfm/optimal_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,21 @@ def __init__(
normalize_cost=False,
**kwargs,
):
r"""Initialize the OTPlanSampler class.
Parameters
----------
method : str
The method used to compute the OT plan. Can be one of "exact", "sinkhorn",
"unbalanced", or "partial".
reg : float (default : 0.05)
Entropic regularization coefficients.
reg_m : float (default : 1.0)
Marginal relaxation term for unbalanced OT (`method='unbalanced'`).
normalize_cost : bool (default : False)
Whether to normalize the cost matrix by its maximum value.
It should be set to `False` when using minibatches.
"""
# ot_fn should take (a, b, M) as arguments where a, b are marginals and
# M is a cost matrix
if method == "exact":
Expand Down

0 comments on commit 25c519e

Please sign in to comment.