diff --git a/torchcfm/optimal_transport.py b/torchcfm/optimal_transport.py index cea5e36..53c53f3 100644 --- a/torchcfm/optimal_transport.py +++ b/torchcfm/optimal_transport.py @@ -1,7 +1,7 @@ import math import warnings from functools import partial -from typing import Optional +from typing import Optional, Union import numpy as np import ot as pot @@ -18,6 +18,7 @@ def __init__( reg: float = 0.05, reg_m: float = 1.0, normalize_cost: bool = False, + num_threads: Union[int, str] = 1, warn: bool = True, ) -> None: """Initialize the OTPlanSampler class. @@ -36,13 +37,16 @@ def __init__( normalizes the cost matrix so that the maximum cost is 1. Helps stabilize Sinkhorn-based solvers. Should not be used in the vast majority of cases. + num_threads: int or str, optional + number of threads to use for the "exact" OT solver. If "max", uses + the maximum number of threads. warn: bool, optional if True, raises a warning if the algorithm does not converge """ # ot_fn should take (a, b, M) as arguments where a, b are marginals and # M is a cost matrix if method == "exact": - self.ot_fn = pot.emd + self.ot_fn = partial(pot.emd, numThreads=num_threads) elif method == "sinkhorn": self.ot_fn = partial(pot.sinkhorn, reg=reg) elif method == "unbalanced":