diff --git a/RELEASES.md b/RELEASES.md index 387527987..b062dbb80 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -10,6 +10,7 @@ + The `convolutional_barycenter2d` and `convolutional_barycenter2d_debiased` functions now work with different devices.. (PR #533) + New API for Gromov-Wasserstein solvers with `ot.solve_gromov` function (PR #536) + New LP solvers from scipy used by default for LP barycenter (PR #537) ++ Upgraded unbalanced OT solvers for more flexibility (PR #539) #### Closed issues - Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504) @@ -19,13 +20,13 @@ ## 0.9.1 *August 2023* -This new release contains several new features and bug fixes. +This new release contains several new features and bug fixes. New features include a new submodule `ot.gnn` that contains two new Graph neural network layers (compatible with [Pytorch Geometric](https://pytorch-geometric.readthedocs.io/)) for template-based pooling of graphs with an example on [graph classification](https://pythonot.github.io/master/auto_examples/gromov/plot_gnn_TFGW.html). Related to this, we also now provide FGW and semi relaxed FGW solvers for which the resulting loss is differentiable w.r.t. the parameter `alpha`. Other contributions on the (F)GW front include a new solver for the Proximal Point algorithm [that can be used to solve entropic GW problems](https://pythonot.github.io/master/auto_examples/gromov/plot_fgw_solvers.html) (using the parameter `solver="PPA"`), new solvers for entropic FGW barycenters, novels Sinkhorn-based solvers for entropic semi-relaxed (F)GW, the possibility to provide a warm-start to the solvers, and optional marginal weights of the samples (uniform weights ar used by default). Finally we added in the submodule `ot.gaussian` and `ot.da` new loss and mapping estimators for the Gaussian Gromov-Wasserstein that can be used as a fast alternative to GW and estimates linear mappings between unregistered spaces that can potentially have different size (See the update [linear mapping example](https://pythonot.github.io/master/auto_examples/domain-adaptation/plot_otda_linear_mapping.html) for an illustration). We also provide a new solver for the [Entropic Wasserstein Component Analysis](https://pythonot.github.io/master/auto_examples/others/plot_EWCA.html) that is a generalization of the celebrated PCA taking into account the local neighborhood of the samples. We also now have a new solver in `ot.smooth` for the [sparsity-constrained OT (last plot)](https://pythonot.github.io/master/auto_examples/plot_OT_1D_smooth.html) that can be used to find regularized OT plans with sparsity constraints. Finally we have a first multi-marginal solver for regular 1D distributions with a Monge loss (see [here](https://pythonot.github.io/master/auto_examples/others/plot_dmmot.html)). -The documentation and testings have also been updated. We now have nearly 95% code coverage with the tests. The documentation has been updated and some examples have been streamlined to build more quickly and avoid timeout problems with CircleCI. We also added an optional CI on GPU for the master branch and approved PRs that can be used when a GPU runner is online. +The documentation and testings have also been updated. We now have nearly 95% code coverage with the tests. The documentation has been updated and some examples have been streamlined to build more quickly and avoid timeout problems with CircleCI. We also added an optional CI on GPU for the master branch and approved PRs that can be used when a GPU runner is online. Many other bugs and issues have been fixed and we want to thank all the contributors, old and new, who made this release possible. More details below. @@ -76,9 +77,9 @@ Many other bugs and issues have been fixed and we want to thank all the contribu *April 2023* This new release contains so many new features and bug fixes since 0.8.2 that we -decided to make it a new minor release at 0.9.0. +decided to make it a new minor release at 0.9.0. -The release contains many new features. First we did a major +The release contains many new features. First we did a major update of all Gromov-Wasserstein solvers that brings up to 30% gain in computation time (see PR #431) and allows the GW solvers to work on non symmetric matrices. It also brings novel solvers for the very @@ -94,7 +95,7 @@ barycenter](https://pythonot.github.io/master/auto_examples/barycenters/plot_fre and the [Generalized Wasserstein barycenter](https://pythonot.github.io/master/auto_examples/barycenters/plot_generalized_free_support_barycenter.html#sphx-glr-auto-examples-barycenters-plot-generalized-free-support-barycenter-py). A new differentiable solver for OT across spaces that provides OT plans -between samples and features simultaneously and +between samples and features simultaneously and called [Co-Optimal Transport](https://pythonot.github.io/master/auto_examples/others/plot_COOT.html) has also been implemented. Finally we began working on OT between Gaussian distributions and @@ -147,7 +148,7 @@ when implementing new solvers but we encourage users to play with it. Finally, in addition to those many new this release fixes 20 issues (some long standing) and we want to thank all the contributors who made this release so big. More details below. - + #### New features - Added feature to (Fused) Gromov-Wasserstein solvers inherited from `ot.optim` to support relative and absolute loss variations as stopping criterions (PR #431) diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 9584d77ca..265006d2c 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -5,6 +5,8 @@ # Author: Hicham Janati # Laetitia Chapel +# Quang Huy Tran +# # License: MIT License from __future__ import division @@ -14,8 +16,7 @@ from scipy.optimize import minimize, Bounds from .backend import get_backend -from .utils import list_to_array -# from .utils import unif, dist +from .utils import list_to_array, get_parameter_pair def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, @@ -27,9 +28,10 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, The function solves the following optimization problem: .. math:: - W = \min_\gamma \ \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg}\cdot\Omega(\gamma) + - \mathrm{reg_m} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + - \mathrm{reg_m} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b}) + W = \min_\gamma \ \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot \Omega(\gamma) + + \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b}) s.t. \gamma \geq 0 @@ -56,8 +58,14 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, loss matrix reg : float Entropy regularization term > 0 - reg_m: float - Marginal relaxation term > 0 + reg_m: float or indexable object of length 1 or 2 + Marginal relaxation term. + If reg_m is a scalar or an indexable object of length 1, + then the same reg_m is applied to both marginal relaxations. + The entropic balanced OT can be recovered using `reg_m=float("inf")`. + For semi-relaxed case, use either + `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. + If reg_m is an array, it must have the same backend as input arrays (a, b, M). method : str method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or 'sinkhorn_reg_scaling', see those function for specific parameters @@ -157,13 +165,13 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', The function solves the following optimization problem: .. math:: - W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg}\cdot\Omega(\gamma) + - \mathrm{reg_m} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + - \mathrm{reg_m} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b}) + W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot \Omega(\gamma) + + \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b}) s.t. \gamma\geq 0 - where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix @@ -186,8 +194,14 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', loss matrix reg : float Entropy regularization term > 0 - reg_m: float - Marginal relaxation term > 0 + reg_m: float or indexable object of length 1 or 2 + Marginal relaxation term. + If reg_m is a scalar or an indexable object of length 1, + then the same reg_m is applied to both marginal relaxations. + The entropic balanced OT can be recovered using `reg_m=float("inf")`. + For semi-relaxed case, use either + `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. + If reg_m is an array, it must have the same backend as input arrays (a, b, M). method : str method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or 'sinkhorn_reg_scaling', see those function for specific parameters @@ -279,9 +293,10 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, The function solves the following optimization problem: .. math:: - W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg}\cdot\Omega(\gamma) + - \mathrm{reg_m} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + - \mathrm{reg_m} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b}) + W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot \Omega(\gamma) + + \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b}) s.t. \gamma \geq 0 @@ -307,8 +322,14 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, loss matrix reg : float Entropy regularization term > 0 - reg_m: float - Marginal relaxation term > 0 + reg_m: float or indexable object of length 1 or 2 + Marginal relaxation term. + If reg_m is a scalar or an indexable object of length 1, + then the same reg_m is applied to both marginal relaxations. + The entropic balanced OT can be recovered using `reg_m=float("inf")`. + For semi-relaxed case, use either + `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. + If reg_m is an array, it must have the same backend as input arrays (a, b, M). numItermax : int, optional Max number of iterations stopThr : float, optional @@ -376,6 +397,8 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, else: n_hists = 0 + reg_m1, reg_m2 = get_parameter_pair(reg_m) + if log: log = {'err': []} @@ -391,7 +414,8 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, K = nx.exp(M / (-reg)) - fi = reg_m / (reg_m + reg) + fi_1 = reg_m1 / (reg_m1 + reg) if reg_m1 != float("inf") else 1 + fi_2 = reg_m2 / (reg_m2 + reg) if reg_m2 != float("inf") else 1 err = 1. @@ -400,9 +424,9 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, vprev = v Kv = nx.dot(K, v) - u = (a / Kv) ** fi + u = (a / Kv) ** fi_1 Ktu = nx.dot(K.T, u) - v = (b / Ktu) ** fi + v = (b / Ktu) ** fi_2 if (nx.any(Ktu == 0.) or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) @@ -461,9 +485,10 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 stabilization as proposed in :ref:`[10] `: .. math:: - W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg}\cdot\Omega(\gamma) + - \mathrm{reg_m} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + - \mathrm{reg_m} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b}) + W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot \Omega(\gamma) + + \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b}) s.t. \gamma \geq 0 @@ -490,8 +515,14 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 loss matrix reg : float Entropy regularization term > 0 - reg_m: float - Marginal relaxation term > 0 + reg_m: float or indexable object of length 1 or 2 + Marginal relaxation term. + If reg_m is a scalar or an indexable object of length 1, + then the same reg_m is applied to both marginal relaxations. + The entropic balanced OT can be recovered using `reg_m=float("inf")`. + For semi-relaxed case, use either + `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. + If reg_m is an array, it must have the same backend as input arrays (a, b, M). tau : float threshold for max value in u or v for log scaling numItermax : int, optional @@ -559,6 +590,8 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 else: n_hists = 0 + reg_m1, reg_m2 = get_parameter_pair(reg_m) + if log: log = {'err': []} @@ -575,26 +608,30 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 # print(reg) K = nx.exp(-M / reg) - fi = reg_m / (reg_m + reg) + fi_1 = reg_m1 / (reg_m1 + reg) if reg_m1 != float("inf") else 1 + fi_2 = reg_m2 / (reg_m2 + reg) if reg_m2 != float("inf") else 1 cpt = 0 err = 1. alpha = nx.zeros(dim_a, type_as=M) beta = nx.zeros(dim_b, type_as=M) + ones_a = nx.ones(dim_a, type_as=M) + ones_b = nx.ones(dim_b, type_as=M) + while (err > stopThr and cpt < numItermax): uprev = u vprev = v Kv = nx.dot(K, v) - f_alpha = nx.exp(- alpha / (reg + reg_m)) - f_beta = nx.exp(- beta / (reg + reg_m)) + f_alpha = nx.exp(- alpha / (reg + reg_m1)) if reg_m1 != float("inf") else ones_a + f_beta = nx.exp(- beta / (reg + reg_m2)) if reg_m2 != float("inf") else ones_b if n_hists: f_alpha = f_alpha[:, None] f_beta = f_beta[:, None] - u = ((a / (Kv + 1e-16)) ** fi) * f_alpha + u = ((a / (Kv + 1e-16)) ** fi_1) * f_alpha Ktu = nx.dot(K.T, u) - v = ((b / (Ktu + 1e-16)) ** fi) * f_beta + v = ((b / (Ktu + 1e-16)) ** fi_2) * f_beta absorbing = False if nx.any(u > tau) or nx.any(v > tau): absorbing = True @@ -1037,7 +1074,7 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, raise ValueError("Unknown method '%s'." % method) -def mm_unbalanced(a, b, M, reg_m, div='kl', G0=None, numItermax=1000, +def mm_unbalanced(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, stopThr=1e-15, verbose=False, log=False): r""" Solve the unbalanced optimal transport problem and return the OT plan. @@ -1045,8 +1082,10 @@ def mm_unbalanced(a, b, M, reg_m, div='kl', G0=None, numItermax=1000, .. math:: W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + - \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + - \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + \mathrm{reg_{m1}} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_{m2}} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + + \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{a} \mathbf{b}^T) + s.t. \gamma \geq 0 @@ -1068,8 +1107,14 @@ def mm_unbalanced(a, b, M, reg_m, div='kl', G0=None, numItermax=1000, Unnormalized histogram of dimension `dim_b` M : array-like (dim_a, dim_b) loss matrix - reg_m: float - Marginal relaxation term > 0 + reg_m: float or indexable object of length 1 or 2 + Marginal relaxation term >= 0, but cannot be infinity. + If reg_m is a scalar or an indexable object of length 1, + then the same reg_m is applied to both marginal relaxations. + If reg_m is an array, it must have the same backend as input arrays (a, b, M). + reg : float, optional (default = 0) + Entropy regularization term >= 0. + By default, solve the unregularized problem div: string, optional Divergence to quantify the difference between the marginals. Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) @@ -1097,12 +1142,12 @@ def mm_unbalanced(a, b, M, reg_m, div='kl', G0=None, numItermax=1000, >>> a=[.5, .5] >>> b=[.5, .5] >>> M=[[1., 36.],[9., 4.]] - >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 1, 'kl'), 2) - array([[0.3 , 0. ], - [0. , 0.07]]) - >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 1, 'l2'), 2) - array([[0.25, 0. ], - [0. , 0. ]]) + >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 5, div='kl'), 2) + array([[0.45, 0. ], + [0. , 0.34]]) + >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 5, div='l2'), 2) + array([[0.4, 0. ], + [0. , 0.1]]) .. _references-regpath: @@ -1116,6 +1161,7 @@ def mm_unbalanced(a, b, M, reg_m, div='kl', G0=None, numItermax=1000, ot.lp.emd : Unregularized OT ot.unbalanced.sinkhorn_unbalanced : Entropic regularized OT """ + M, a, b = list_to_array(M, a, b) nx = get_backend(M, a, b) @@ -1131,30 +1177,35 @@ def mm_unbalanced(a, b, M, reg_m, div='kl', G0=None, numItermax=1000, else: G = G0 + reg_m1, reg_m2 = get_parameter_pair(reg_m) + if log: log = {'err': [], 'G': []} if div == 'kl': - K = nx.exp(M / - reg_m / 2) + sum_r = reg + reg_m1 + reg_m2 + r1, r2, r = reg_m1 / sum_r, reg_m2 / sum_r, reg / sum_r + K = a[:, None]**(r1 + r) * b[None, :]**(r2 + r) * nx.exp(- M / sum_r) elif div == 'l2': - K = nx.maximum(a[:, None] + b[None, :] - M / reg_m / 2, - nx.zeros((dim_a, dim_b), type_as=M)) + K = reg_m1 * a[:, None] + reg_m2 * b[None, :] + reg * a[:, None] * b[None, :] - M + K = nx.maximum(K, nx.zeros((dim_a, dim_b), type_as=M)) else: warnings.warn("The div parameter should be either equal to 'kl' or \ 'l2': it has been set to 'kl'.") div = 'kl' - K = nx.exp(M / - reg_m / 2) + sum_r = reg + reg_m1 + reg_m2 + r1, r2, r = reg_m1 / sum_r, reg_m2 / sum_r, reg / sum_r + K = a[:, None]**(r1 + r) * b[None, :]**(r2 + r) * nx.exp(- M / sum_r) for i in range(numItermax): Gprev = G if div == 'kl': - u = nx.sqrt(a / (nx.sum(G, 1) + 1e-16)) - v = nx.sqrt(b / (nx.sum(G, 0) + 1e-16)) - G = G * K * u[:, None] * v[None, :] + G = K * G**(r1 + r2) / (nx.sum(G, 1, keepdims=True)**r1 * nx.sum(G, 0, keepdims=True)**r2 + 1e-16) elif div == 'l2': - Gd = nx.sum(G, 0, keepdims=True) + nx.sum(G, 1, keepdims=True) + 1e-16 - G = G * K / Gd + Gd = reg_m1 * nx.sum(G, 1, keepdims=True) + \ + reg_m2 * nx.sum(G, 0, keepdims=True) + reg * G + 1e-16 + G = K * G / Gd err = nx.sqrt(nx.sum((G - Gprev) ** 2)) if log: @@ -1172,7 +1223,7 @@ def mm_unbalanced(a, b, M, reg_m, div='kl', G0=None, numItermax=1000, return G -def mm_unbalanced2(a, b, M, reg_m, div='kl', G0=None, numItermax=1000, +def mm_unbalanced2(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, stopThr=1e-15, verbose=False, log=False): r""" Solve the unbalanced optimal transport problem and return the OT plan. @@ -1180,8 +1231,9 @@ def mm_unbalanced2(a, b, M, reg_m, div='kl', G0=None, numItermax=1000, .. math:: W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + - \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + - \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + \mathrm{reg_{m1}} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_{m2}} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + + \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{a} \mathbf{b}^T) s.t. \gamma \geq 0 @@ -1204,8 +1256,14 @@ def mm_unbalanced2(a, b, M, reg_m, div='kl', G0=None, numItermax=1000, Unnormalized histogram of dimension `dim_b` M : array-like (dim_a, dim_b) loss matrix - reg_m: float - Marginal relaxation term > 0 + reg_m: float or indexable object of length 1 or 2 + Marginal relaxation term >= 0, but cannot be infinity. + If reg_m is a scalar or an indexable object of length 1, + then the same reg_m is applied to both marginal relaxations. + If reg_m is an array, it must have the same backend as input arrays (a, b, M). + reg : float, optional (default = 0) + Entropy regularization term >= 0. + By default, solve the unregularized problem div: string, optional Divergence to quantify the difference between the marginals. Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) @@ -1234,10 +1292,10 @@ def mm_unbalanced2(a, b, M, reg_m, div='kl', G0=None, numItermax=1000, >>> a=[.5, .5] >>> b=[.5, .5] >>> M=[[1., 36.],[9., 4.]] - >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'l2'),2) - 0.25 - >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'kl'),2) - 0.57 + >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 5, div='l2'), 2) + 0.8 + >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 5, div='kl'), 2) + 1.79 References ---------- @@ -1249,7 +1307,7 @@ def mm_unbalanced2(a, b, M, reg_m, div='kl', G0=None, numItermax=1000, ot.lp.emd2 : Unregularized OT loss ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss """ - _, log_mm = mm_unbalanced(a, b, M, reg_m, div=div, G0=G0, + _, log_mm = mm_unbalanced(a, b, M, reg_m, reg=reg, div=div, G0=G0, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=True) @@ -1259,7 +1317,7 @@ def mm_unbalanced2(a, b, M, reg_m, div='kl', G0=None, numItermax=1000, return log_mm['cost'] -def _get_loss_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl'): +def _get_loss_unbalanced(a, b, M, reg, reg_m1, reg_m2, reg_div='kl', regm_div='kl'): """ return the loss function (scipy.optimize compatible) for regularized unbalanced OT @@ -1268,7 +1326,7 @@ def _get_loss_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl'): m, n = M.shape def kl(p, q): - return np.sum(p * np.log(p / q + 1e-16)) + return np.sum(p * np.log(p / q + 1e-16)) - p.sum() + q.sum() def reg_l2(G): return np.sum((G - a[:, None] * b[None, :])**2) / 2 @@ -1280,10 +1338,10 @@ def reg_kl(G): return kl(G, a[:, None] * b[None, :]) def grad_kl(G): - return np.log(G / (a[:, None] * b[None, :]) + 1e-16) + 1 + return np.log(G / (a[:, None] * b[None, :]) + 1e-16) def reg_entropy(G): - return kl(G, 1) + return np.sum(G * np.log(G + 1e-16)) def grad_entropy(G): return np.log(G + 1e-16) + 1 @@ -1299,16 +1357,19 @@ def grad_entropy(G): grad_reg_fun = grad_l2 def marg_l2(G): - return 0.5 * np.sum((G.sum(1) - a)**2) + 0.5 * np.sum((G.sum(0) - b)**2) + return reg_m1 * 0.5 * np.sum((G.sum(1) - a)**2) + \ + reg_m2 * 0.5 * np.sum((G.sum(0) - b)**2) def grad_marg_l2(G): - return np.outer((G.sum(1) - a), np.ones(n)) + np.outer(np.ones(m), (G.sum(0) - b)) + return reg_m1 * np.outer((G.sum(1) - a), np.ones(n)) + \ + reg_m2 * np.outer(np.ones(m), (G.sum(0) - b)) def marg_kl(G): - return kl(G.sum(1), a) + kl(G.sum(0), b) + return reg_m1 * kl(G.sum(1), a) + reg_m2 * kl(G.sum(0), b) def grad_marg_kl(G): - return np.outer(np.log(G.sum(1) / a + 1e-16) + 1, np.ones(n)) + np.outer(np.ones(m), np.log(G.sum(0) / b + 1e-16) + 1) + return reg_m1 * np.outer(np.log(G.sum(1) / a + 1e-16), np.ones(n)) + \ + reg_m2 * np.outer(np.ones(m), np.log(G.sum(0) / b + 1e-16)) if regm_div == 'kl': regm_fun = marg_kl @@ -1321,10 +1382,10 @@ def _func(G): G = G.reshape((m, n)) # compute loss - val = np.sum(G * M) + reg * reg_fun(G) + reg_m * regm_fun(G) + val = np.sum(G * M) + reg * reg_fun(G) + regm_fun(G) # compute gradient - grad = M + reg * grad_reg_fun(G) + reg_m * grad_regm_fun(G) + grad = M + reg * grad_reg_fun(G) + grad_regm_fun(G) return val, grad.ravel() @@ -1339,9 +1400,9 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, .. math:: W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + - + \mathrm{reg} \mathrm{div}(\gamma,\mathbf{a}\mathbf{b}^T) - \mathrm{reg_m} \cdot \mathrm{div_m}(\gamma \mathbf{1}, \mathbf{a}) + - \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + + \mathrm{reg} \mathrm{div}(\gamma, \mathbf{a} \mathbf{b}^T) + \mathrm{reg_{m1}} \cdot \mathrm{div_m}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_{m2}} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) s.t. \gamma \geq 0 @@ -1364,13 +1425,16 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, M : array-like (dim_a, dim_b) loss matrix reg: float - regularization term (>=0) - reg_m: float - Marginal relaxation term >= 0 + regularization term >=0 + reg_m: float or indexable object of length 1 or 2 + Marginal relaxation term >= 0, but cannot be infinity. + If reg_m is a scalar or an indexable object of length 1, + then the same reg_m is applied to both marginal relaxations. + If reg_m is an array, it must be a Numpy array. reg_div: string, optional Divergence used for regularization. Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) - reg_div: string, optional + regm_div: string, optional Divergence to quantify the difference between the marginals. Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) G0: array-like (dim_a, dim_b) @@ -1386,8 +1450,8 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, Returns ------- - ot_distance : array-like - the OT distance between :math:`\mathbf{a}` and :math:`\mathbf{b}` + gamma : (dim_a, dim_b) array-like + Optimal transportation matrix for the given parameters log : dict log dictionary returned only if `log` is `True` @@ -1398,10 +1462,12 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, >>> a=[.5, .5] >>> b=[.5, .5] >>> M=[[1., 36.],[9., 4.]] - >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'l2'),2) - 0.25 - >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'kl'),2) - 0.57 + >>> np.round(ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=0, reg_m=5, reg_div='kl', regm_div='kl'), 2) + array([[0.45, 0. ], + [0. , 0.34]]) + >>> np.round(ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=0, reg_m=5, reg_div='l2', regm_div='l2'), 2) + array([[0.4, 0. ], + [0. , 0.1]]) References ---------- @@ -1413,18 +1479,22 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, ot.lp.emd2 : Unregularized OT loss ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss """ + + M, a, b = list_to_array(M, a, b) nx = get_backend(M, a, b) M0 = M - # convert to humpy + # convert to numpy a, b, M = nx.to_numpy(a, b, M) + reg_m1, reg_m2 = get_parameter_pair(reg_m) + if G0 is not None: G0 = nx.to_numpy(G0) else: G0 = np.zeros(M.shape) - _func = _get_loss_unbalanced(a, b, M, reg, reg_m, reg_div, regm_div) + _func = _get_loss_unbalanced(a, b, M, reg, reg_m1, reg_m2, reg_div, regm_div) res = minimize(_func, G0.ravel(), method=method, jac=True, bounds=Bounds(0, np.inf), tol=stopThr, options=dict(maxiter=numItermax, disp=verbose)) diff --git a/ot/utils.py b/ot/utils.py index 8cbb0db25..4efcb225e 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -492,6 +492,36 @@ def get_coordinate_circle(x): return x_t +def get_parameter_pair(parameter): + r"""Extract a pair of parameters from a given parameter + Used in unbalanced OT and COOT solvers + to handle marginal regularization and entropic regularization. + + Parameters + ---------- + parameter : float or indexable object + nx : backend object + + Returns + ------- + param_1 : float + param_2 : float + """ + + if isinstance(parameter, float) or isinstance(parameter, int): + param_1, param_2 = parameter, parameter + elif len(parameter) == 1: + param_1, param_2 = parameter[0], parameter[0] + else: + if len(parameter) > 2: + raise ValueError("Parameter must be either a scalar, \ + or an indexable object of length 1 or 2.") + else: + param_1, param_2 = parameter[0], parameter[1] + + return param_1, param_2 + + class deprecated(object): r"""Decorator to mark a function or class as deprecated. diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index 24e5bc427..272794cb8 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -2,9 +2,11 @@ # Author: Hicham Janati # Laetitia Chapel +# Quang Huy Tran # # License: MIT License + import itertools import numpy as np import ot @@ -71,6 +73,50 @@ def test_unbalanced_convergence(nx, method): np.testing.assert_allclose(G_np, nx.to_numpy(G)) +@pytest.mark.parametrize("method,reg_m", itertools.product(["sinkhorn", "sinkhorn_stabilized"], [1, float("inf")])) +def test_unbalanced_relaxation_parameters(nx, method, reg_m): + # test generalized sinkhorn for unbalanced OT + n = 100 + rng = np.random.RandomState(50) + + x = rng.randn(n, 2) + a = ot.utils.unif(n) + + # make dists unbalanced + b = rng.rand(n, 2) + + M = ot.dist(x, x) + epsilon = 1. + + a, b, M = nx.from_numpy(a, b, M) + + # options for reg_m + full_list_reg_m = [reg_m, reg_m] + full_tuple_reg_m = (reg_m, reg_m) + tuple_reg_m, list_reg_m = (reg_m), [reg_m] + nx_reg_m = reg_m * nx.ones(1) + list_options = [nx_reg_m, full_tuple_reg_m, + tuple_reg_m, full_list_reg_m, list_reg_m] + + loss, log = ot.unbalanced.sinkhorn_unbalanced( + a, b, M, reg=epsilon, reg_m=reg_m, + method=method, log=True, verbose=True + ) + + for opt in list_options: + loss_opt, log_opt = ot.unbalanced.sinkhorn_unbalanced( + a, b, M, reg=epsilon, reg_m=opt, + method=method, log=True, verbose=True + ) + + np.testing.assert_allclose( + nx.to_numpy(log["logu"]), nx.to_numpy(log_opt["logu"]), atol=1e-05) + np.testing.assert_allclose( + nx.to_numpy(log["logv"]), nx.to_numpy(log_opt["logv"]), atol=1e-05) + np.testing.assert_allclose( + nx.to_numpy(loss), nx.to_numpy(loss_opt), atol=1e-05) + + @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) def test_unbalanced_multiple_inputs(nx, method): # test generalized sinkhorn for unbalanced OT @@ -111,8 +157,6 @@ def test_unbalanced_multiple_inputs(nx, method): np.testing.assert_allclose( nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05) - assert len(loss) == b.shape[1] - def test_stabilized_vs_sinkhorn(nx): # test if stable version matches sinkhorn @@ -292,9 +336,11 @@ def test_implemented_methods(nx): @pytest.mark.parametrize("reg_div,regm_div", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2'])) def test_lbfgsb_unbalanced(nx, reg_div, regm_div): - rng = np.random.RandomState(42) - xs = rng.randn(5, 2) - xt = rng.randn(6, 2) + + np.random.seed(42) + + xs = np.random.randn(5, 2) + xt = np.random.randn(6, 2) M = ot.dist(xs, xt) @@ -310,7 +356,46 @@ def test_lbfgsb_unbalanced(nx, reg_div, regm_div): np.testing.assert_allclose(G, nx.to_numpy(Gb)) -def test_mm_convergence(nx): +@pytest.mark.parametrize("reg_div,regm_div", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2'])) +def test_lbfgsb_unbalanced_relaxation_parameters(nx, reg_div, regm_div): + + np.random.seed(42) + + xs = np.random.randn(5, 2) + xt = np.random.randn(6, 2) + + M = ot.dist(xs, xt) + + a = ot.unif(5) + b = ot.unif(6) + + a, b, M = nx.from_numpy(a, b, M) + + reg_m = 10 + full_list_reg_m = [reg_m, reg_m] + full_tuple_reg_m = (reg_m, reg_m) + tuple_reg_m, list_reg_m = (reg_m), [reg_m] + np1_reg_m = reg_m * np.ones(1) + np2_reg_m = reg_m * np.ones(2) + + list_options = [np1_reg_m, np2_reg_m, full_tuple_reg_m, + tuple_reg_m, full_list_reg_m, list_reg_m] + + G = ot.unbalanced.lbfgsb_unbalanced(a, b, M, 1, reg_m=reg_m, + reg_div=reg_div, regm_div=regm_div, + log=False, verbose=False) + + for opt in list_options: + G0 = ot.unbalanced.lbfgsb_unbalanced( + a, b, M, 1, reg_m=opt, reg_div=reg_div, + regm_div=regm_div, log=False, verbose=False + ) + + np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-06) + + +@pytest.mark.parametrize("div", ["kl", "l2"]) +def test_mm_convergence(nx, div): n = 100 rng = np.random.RandomState(42) x = rng.randn(n, 2) @@ -324,38 +409,110 @@ def test_mm_convergence(nx): reg_m = 100 a, b, M = nx.from_numpy(a_np, b_np, M) - G_kl, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', - verbose=False, log=True) - loss_kl = nx.to_numpy( - ot.unbalanced.mm_unbalanced2(a, b, M, reg_m, div='kl', verbose=True) + G, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div=div, + verbose=False, log=True) + loss = nx.to_numpy( + ot.unbalanced.mm_unbalanced2(a, b, M, reg_m, div=div, verbose=True) ) - G_l2, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', - verbose=False, log=True) # check if the marginals come close to the true ones when large reg - np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 1), a_np, atol=1e-03) - np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 0), b_np, atol=1e-03) - np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 1), a_np, atol=1e-03) - np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 0), b_np, atol=1e-03) + np.testing.assert_allclose(np.sum(nx.to_numpy(G), 1), a_np, atol=1e-03) + np.testing.assert_allclose(np.sum(nx.to_numpy(G), 0), b_np, atol=1e-03) # check if mm_unbalanced2 returns the correct loss - np.testing.assert_allclose(nx.to_numpy(nx.sum(G_kl * M)), loss_kl, - atol=1e-5) + np.testing.assert_allclose(nx.to_numpy(nx.sum(G * M)), loss, atol=1e-5) # check in case no histogram is provided a_np, b_np = np.array([]), np.array([]) a, b = nx.from_numpy(a_np, b_np) - G_kl_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', verbose=False) - G_l2_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', verbose=False) - np.testing.assert_allclose(nx.to_numpy(G_kl_null), nx.to_numpy(G_kl)) - np.testing.assert_allclose(nx.to_numpy(G_l2_null), nx.to_numpy(G_l2)) + G_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div=div, verbose=False) + np.testing.assert_allclose(nx.to_numpy(G_null), nx.to_numpy(G)) # test when G0 is given G0 = ot.emd(a, b, M) G0_np = nx.to_numpy(G0) reg_m = 10000 - G_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', G0=G0, verbose=False) - G_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', G0=G0, verbose=False) - np.testing.assert_allclose(G0_np, nx.to_numpy(G_kl), atol=1e-05) - np.testing.assert_allclose(G0_np, nx.to_numpy(G_l2), atol=1e-05) + G = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div=div, G0=G0, verbose=False) + np.testing.assert_allclose(G0_np, nx.to_numpy(G), atol=1e-05) + + +@pytest.mark.parametrize("div", ["kl", "l2"]) +def test_mm_relaxation_parameters(nx, div): + n = 100 + rng = np.random.RandomState(42) + x = rng.randn(n, 2) + rng = np.random.RandomState(75) + y = rng.randn(n, 2) + a_np = ot.utils.unif(n) + b_np = ot.utils.unif(n) + + M = ot.dist(x, y) + M = M / M.max() + a, b, M = nx.from_numpy(a_np, b_np, M) + + reg = 1e-2 + + reg_m = 100 + full_list_reg_m = [reg_m, reg_m] + full_tuple_reg_m = (reg_m, reg_m) + tuple_reg_m, list_reg_m = (reg_m), [reg_m] + nx1_reg_m = reg_m * nx.ones(1) + nx2_reg_m = reg_m * nx.ones(2) + + list_options = [nx1_reg_m, nx2_reg_m, full_tuple_reg_m, + tuple_reg_m, full_list_reg_m, list_reg_m] + + G0, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, reg=reg, + div=div, verbose=False, log=True) + loss_0 = nx.to_numpy( + ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, reg=reg, + div=div, verbose=True) + ) + + for opt in list_options: + G1, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=opt, + reg=reg, div=div, + verbose=False, log=True) + loss_1 = nx.to_numpy( + ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=opt, + reg=reg, div=div, verbose=True) + ) + + np.testing.assert_allclose(nx.to_numpy(G0), nx.to_numpy(G1), atol=1e-05) + np.testing.assert_allclose(loss_0, loss_1, atol=1e-5) + + +def test_mm_wrong_divergence(nx): + + n = 100 + rng = np.random.RandomState(42) + x = rng.randn(n, 2) + rng = np.random.RandomState(75) + y = rng.randn(n, 2) + a_np = ot.utils.unif(n) + b_np = ot.utils.unif(n) + + M = ot.dist(x, y) + M = M / M.max() + a, b, M = nx.from_numpy(a_np, b_np, M) + + reg = 1e-2 + reg_m = 100 + + G0, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, reg=reg, + div="kl", verbose=False, log=True) + loss_0 = nx.to_numpy( + ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, reg=reg, + div="kl", verbose=True) + ) + + G1, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, reg=reg, + div="wrong_div", verbose=False, log=True) + loss_1 = nx.to_numpy( + ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, reg=reg, + div="wrong_div", verbose=True) + ) + + np.testing.assert_allclose(nx.to_numpy(G0), nx.to_numpy(G1), atol=1e-05) + np.testing.assert_allclose(loss_0, loss_1, atol=1e-5)