From fe20bc6f5d2051e57ada85644e4f8303fbf46bdf Mon Sep 17 00:00:00 2001 From: Huy Tran Date: Mon, 6 Nov 2023 10:44:09 +0100 Subject: [PATCH] [WIP] Add new features to unbalanced solvers (#551) * add new features to unbalanced solvers * add new features to unbalanced solvers * fix bug in test * remove stab_sinkhorn * remove kl * fix bug in lbfgsb_unbalanced * fix bug in lbfgsb_unbalanced * fix bug in KL in sinkhorn_unbalanced * edit release.md * fix test * add test and rearrange arguments * fix test * fix test * fix test * fix bug in test * fix bug in doctest * fix bug in doctest * add test for more coverage --- RELEASES.md | 163 ++------------------- ot/unbalanced.py | 303 +++++++++++++++++++++++++--------------- test/test_unbalanced.py | 202 ++++++++++++++++++++++++--- 3 files changed, 386 insertions(+), 282 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index cdc986624..97834cdfe 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -13,8 +13,7 @@ + Update wheels to Python 3.12 and remove old i686 arch that do not have scipy wheels (PR #543) + Upgraded unbalanced OT solvers for more flexibility (PR #539) + Add LazyTensor for modeling plans and low rank tensor in large scale OT (PR #544) -+ Add exact line-search for `gromov_wasserstein` and `fused_gromov_wasserstein` with KL loss (PR #556) -+ Add KL loss to all semi-relaxed (Fused) Gromov-Wasserstein solvers (PR #559) ++ Further upgraded unbalanced OT solvers for more flexibility and future use (PR #551) #### Closed issues - Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504) @@ -36,143 +35,13 @@ Many other bugs and issues have been fixed and we want to thank all the contribu #### New features -- Gaussian Gromov Wasserstein loss and mapping (PR #498) -- Template-based Fused Gromov Wasserstein GNN layer in `ot.gnn` (PR #488) -- Make alpha parameter in semi-relaxed Fused Gromov Wasserstein differentiable (PR #483) -- Make alpha parameter in Fused Gromov Wasserstein differentiable (PR #463) -- Added the sparsity-constrained OT solver to `ot.smooth` and added `projection_sparse_simplex` to `ot.utils` (PR #459) -- Add tests on GPU for master branch and approved PR (PR #473) -- Add `median` method to all inherited classes of `backend.Backend` (PR #472) -- Update tests for macOS and Windows, speedup documentation (PR #484) -- Added Proximal Point algorithm to solve GW problems via a new parameter `solver="PPA"` in `ot.gromov.entropic_gromov_wasserstein` + examples (PR #455) -- Added features `warmstart` and `kwargs` in `ot.gromov.entropic_gromov_wasserstein` to respectively perform warmstart on dual potentials and pass parameters to `ot.sinkhorn` (PR #455) -- Added sinkhorn projection based solvers for FGW `ot.gromov.entropic_fused_gromov_wasserstein` and entropic FGW barycenters + examples (PR #455) -- Added features `warmstartT` and `kwargs` to all CG and entropic (F)GW barycenter solvers (PR #455) -- Added entropic semi-relaxed (Fused) Gromov-Wasserstein solvers in `ot.gromov` + examples (PR #455) -- Make marginal parameters optional for (F)GW solvers in `._gw`, `._bregman` and `._semirelaxed` (PR #455) -- Add Entropic Wasserstein Component Analysis (ECWA) in ot.dr (PR #486) -- Added feature Efficient Discrete Multi Marginal Optimal Transport Regularization + examples (PR #454) -#### Closed issues - -- Fix gromov conventions (PR #497) -- Fix change in scipy API for `cdist` (PR #487) -- More permissive check_backend (PR #494) -- Fix circleci-redirector action and codecov (PR #460) -- Fix issues with cuda for ot.binary_search_circle and with gradients for ot.sliced_wasserstein_sphere (PR #457) -- Major documentation cleanup (PR #462, PR #467, PR #475) -- Fix gradients for "Wasserstein2 Minibatch GAN" example (PR #466) -- Faster Bures-Wasserstein distance with NumPy backend (PR #468) -- Fix issue backend for ot.sliced_wasserstein_sphere ot.sliced_wasserstein_sphere_unif (PR #471) -- Fix issue with ot.barycenter_stabilized when used with PyTorch tensors and log=True (PR #474) -- Fix `utils.cost_normalization` function issue to work with multiple backends (PR #472) -- Fix pression error on marginal sums and (Issue #429, PR #496) - -#### New Contributors -* @kachayev made their first contribution in PR #462 -* @liutianlin0121 made their first contribution in PR #459 -* @francois-rozet made their first contribution in PR #468 -* @framunoz made their first contribution in PR #472 -* @SoniaMaz8 made their first contribution in PR #483 -* @tomMoral made their first contribution in PR #494 -* @12hengyu made their first contribution in PR #454 - -## 0.9.0 -*April 2023* - -This new release contains so many new features and bug fixes since 0.8.2 that we -decided to make it a new minor release at 0.9.0. - -The release contains many new features. First we did a major -update of all Gromov-Wasserstein solvers that brings up to 30% gain in -computation time (see PR #431) and allows the GW solvers to work on non symmetric -matrices. It also brings novel solvers for the very -efficient [semi-relaxed GW problem -](https://pythonot.github.io/master/auto_examples/gromov/plot_semirelaxed_fgw.html#sphx-glr-auto-examples-gromov-plot-semirelaxed-fgw-py) -that can be used to find the best re-weighting for one of the distributions. We -also now have fast and differentiable solvers for [Wasserstein on the circle](https://pythonot.github.io/master/auto_examples/plot_compute_wasserstein_circle.html#sphx-glr-auto-examples-plot-compute-wasserstein-circle-py) and -[sliced Wasserstein on the -sphere](https://pythonot.github.io/master/auto_examples/backends/plot_ssw_unif_torch.html#sphx-glr-auto-examples-backends-plot-ssw-unif-torch-py). -We are also very happy to provide new OT barycenter solvers such as the [Free -support Sinkhorn -barycenter](https://pythonot.github.io/master/auto_examples/barycenters/plot_free_support_sinkhorn_barycenter.html#sphx-glr-auto-examples-barycenters-plot-free-support-sinkhorn-barycenter-py) -and the [Generalized Wasserstein -barycenter](https://pythonot.github.io/master/auto_examples/barycenters/plot_generalized_free_support_barycenter.html#sphx-glr-auto-examples-barycenters-plot-generalized-free-support-barycenter-py). -A new differentiable solver for OT across spaces that provides OT plans -between samples and features simultaneously and -called [Co-Optimal -Transport](https://pythonot.github.io/master/auto_examples/others/plot_COOT.html) -has also been implemented. Finally we began working on OT between Gaussian distributions and -now provide differentiable estimation for the Bures-Wasserstein [divergence](https://pythonot.github.io/master/gen_modules/ot.gaussian.html#ot.gaussian.bures_wasserstein_distance) and -[mappings](https://pythonot.github.io/master/auto_examples/domain-adaptation/plot_otda_linear_mapping.html#sphx-glr-auto-examples-domain-adaptation-plot-otda-linear-mapping-py). - -Another important first step toward POT 1.0 is the -implementation of a unified API for OT solvers with introduction of [`ot.solve`](https://pythonot.github.io/master/all.html#ot.solve) -function that can solve (depending on parameters) exact, regularized and -unbalanced OT and return a new -[`OTResult`](https://pythonot.github.io/master/gen_modules/ot.utils.html#ot.utils.OTResult) -object. The idea behind this new API is to facilitate exploring different solvers -with just a change of parameter and get a more unified API for them. We will keep -the old solvers API for power users but it will be the preferred way to solve -problems starting from release 1.0.0. -We provide below some examples of use for the new function and how to -recover different aspects of the solution (OT plan, full loss, linear part of the -loss, dual variables) : -```python -#Solve exact ot -sol = ot.solve(M) - -# get the results -G = sol.plan # OT plan -ot_loss = sol.value # OT value (full loss for regularized and unbalanced) -ot_loss_linear = sol.value_linear # OT value for linear term np.sum(sol.plan*M) -alpha, beta = sol.potentials # dual potentials - -# direct plan and loss computation -G = ot.solve(M).plan -ot_loss = ot.solve(M).value - -# OT exact with marginals a/b -sol2 = ot.solve(M, a, b) - -# regularized and unbalanced OT -sol_rkl = ot.solve(M, a, b, reg=1) # KL regularization -sol_rl2 = ot.solve(M, a, b, reg=1, reg_type='L2') -sol_ul2 = ot.solve(M, a, b, unbalanced=10, unbalanced_type='L2') -sol_rkl_ukl = ot.solve(M, a, b, reg=10, unbalanced=10) # KL + KL - -``` -The function is fully compatible with backends and will be implemented for -different types of distribution support (empirical distributions, grids) and OT -problems (Gromov-Wasserstein) in the new releases. This new API is not yet -presented in the kickstart part of the documentation as there is a small change -that it might change -when implementing new solvers but we encourage users to play with it. - -Finally, in addition to those many new this release fixes 20 issues (some long -standing) and we want to thank all the contributors who made this release so -big. More details below. - - -#### New features -- Added feature to (Fused) Gromov-Wasserstein solvers inherited from `ot.optim` to support relative and absolute loss variations as stopping criterions (PR #431) -- Added feature to (Fused) Gromov-Wasserstein solvers to handle asymmetric matrices (PR #431) -- Added semi-relaxed (Fused) Gromov-Wasserstein solvers in `ot.gromov` + examples (PR #431) -- Added the spherical sliced-Wasserstein discrepancy in `ot.sliced.sliced_wasserstein_sphere` and `ot.sliced.sliced_wasserstein_sphere_unif` + examples (PR #434) -- Added the Wasserstein distance on the circle in ``ot.lp.solver_1d.wasserstein_circle`` (PR #434) -- Added the Wasserstein distance on the circle (for p>=1) in `ot.lp.solver_1d.binary_search_circle` + examples (PR #434) -- Added the 2-Wasserstein distance on the circle w.r.t a uniform distribution in `ot.lp.solver_1d.semidiscrete_wasserstein2_unif_circle` (PR #434) - Added Bures Wasserstein distance in `ot.gaussian` (PR ##428) - Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376) - Added Free Support Sinkhorn Barycenter + example (PR #387) - New API for OT solver using function `ot.solve` (PR #388) -- Backend version of `ot.partial` and `ot.smooth` (PR #388 and #449) -- Added argument for warmstart of dual potentials in Sinkhorn-based methods in `ot.bregman` (PR #437) -- Added parameters method in `ot.da.SinkhornTransport` (PR #440) -- `ot.dr` now uses the new Pymanopt API and POT is compatible with current - Pymanopt (PR #443) -- Added CO-Optimal Transport solver + examples (PR #447) -- Remove the redundant `nx.abs()` at the end of `wasserstein_1d()` (PR #448) +- Backend version of `ot.partial` and `ot.smooth` (PR #388) +- Added argument for warmstart of dual vectors in Sinkhorn-based methods in `ot.bregman` (PR #437) #### Closed issues @@ -200,11 +69,9 @@ PR #413) - Fix an issue where the parameter `stopThr` in `empirical_sinkhorn_divergence` was rendered useless by subcalls that explicitly specified `stopThr=1e-9` (Issue #421, PR #422). - Fixed a bug breaking an example where we would try to make an array of arrays of different shapes (Issue #424, PR #425) -- Fixed an issue with the documentation gallery section (PR #444) -- Fixed issues with cuda variables for `line_search_armijo` and `entropic_gromov_wasserstein` (Issue #445, #PR 446) + ## 0.8.2 -*April 2022* This releases introduces several new notable features. The less important but most exiting one being that we now have a logo for the toolbox (color @@ -348,7 +215,7 @@ a [Generative Network (GAN)](https://PythonOT.github.io/auto_examples/backends/plot_wass2_gan_torch.html), for a [sliced Wasserstein gradient flow](https://PythonOT.github.io/auto_examples/backends/plot_sliced_wass_grad_flow_pytorch.html) -and [optimizing the Gromov-Wasserstein distance](https://PythonOT.github.io/auto_examples/backends/plot_optim_gromov_pytorch.html). Note that the Jax backend is still in early development and quite +and [optimizing the Gromov-Wassersein distance](https://PythonOT.github.io/auto_examples/backends/plot_optim_gromov_pytorch.html). Note that the Jax backend is still in early development and quite slow at the moment, we strongly recommend for Jax users to use the [OTT toolbox](https://github.com/google-research/ott) when possible. As a result of this new feature, @@ -360,7 +227,7 @@ Pointwise Gromov Wasserstein](https://PythonOT.github.io/auto_examples/gromov/plot_gromov.html#compute-gw-with-a-scalable-stochastic-method-with-any-loss-function), Sinkhorn in log space with `method='sinkhorn_log'`, [Projection Robust Wasserstein](https://PythonOT.github.io/gen_modules/ot.dr.html?highlight=robust#ot.dr.projection_robust_wasserstein), -ans [debiased Sinkhorn barycenters](https://PythonOT.github.ioauto_examples/barycenters/plot_debiased_barycenter.html). +ans [deviased Sinkorn barycenters](https://PythonOT.github.ioauto_examples/barycenters/plot_debiased_barycenter.html). This release will also simplify the installation process. We have now a `pyproject.toml` that defines the build dependency and POT should now build even @@ -501,7 +368,7 @@ are coming for the next versions. #### Closed issues -- Add JMLR paper to the readme and Mathieu Blondel to the Acknowledgments (PR +- Add JMLR paper to the readme and Mathieu Blondel to the Acknoledgments (PR #231, #232) - Bug in Unbalanced OT example (Issue #127) - Clean Cython output when calling setup.py clean (Issue #122) @@ -509,7 +376,7 @@ are coming for the next versions. - EMD dimension mismatch (Issue #114, Fixed in PR #116) - 2D barycenter bug for non square images (Issue #124, fixed in PR #132) - Bad value in EMD 1D (Issue #138, fixed in PR #139) -- Log bugs for Gromov-Wasserstein solver (Issue #107, fixed in PR #108) +- Log bugs for Gromov-Wassertein solver (Issue #107, fixed in PR #108) - Weight issues in barycenter function (PR #106) ## 0.6.0 @@ -540,9 +407,9 @@ a solver for [Unbalanced OT barycenters](https://github.com/rflamary/POT/blob/master/notebooks/plot_UOT_barycenter_1D.ipynb). A new variant of Gromov-Wasserstein divergence called [Fused Gromov-Wasserstein](https://pot.readthedocs.io/en/latest/all.html?highlight=fused_#ot.gromov.fused_gromov_wasserstein) -has been also contributed with examples of use on [structured +has been also contributed with exemples of use on [structured data](https://github.com/rflamary/POT/blob/master/notebooks/plot_fgw.ipynb) and -computing [barycenters of labeled +computing [barycenters of labeld graphs](https://github.com/rflamary/POT/blob/master/notebooks/plot_barycenter_fgw.ipynb). @@ -603,7 +470,7 @@ and [free support](https://github.com/rflamary/POT/blob/master/notebooks/plot_fr implementation of entropic OT. POT 0.5 also comes with a rewriting of ot.gpu using the cupy framework instead of -the unmaintained cudamat. Note that while we tried to keep changes to the +the unmaintained cudamat. Note that while we tried to keed changes to the minimum, the OTDA classes were deprecated. If you are happy with the cudamat implementation, we recommend you stay with stable release 0.4 for now. @@ -627,7 +494,7 @@ and new POT contributors (you can see the list in the [readme](https://github.co * Stochastic OT in the dual and semi-dual (PR #52 and PR #62) * Free support barycenters (PR #56) * Speed-up Sinkhorn function (PR #57 and PR #58) -* Add convolutional Wasserstein barycenters for 2D images (PR #64) +* Add convolutional Wassersein barycenters for 2D images (PR #64) * Add Greedy Sinkhorn variant (Greenkhorn) (PR #66) * Big ot.gpu update with cupy implementation (instead of un-maintained cudamat) (PR #67) @@ -678,7 +545,7 @@ This release contains a lot of contribution from new contributors. * new notebooks for emd computation and Wasserstein Discriminant Analysis * relocate notebooks * update documentation -* clean_zeros(a,b,M) for removing zeros in sparse distributions +* clean_zeros(a,b,M) for removimg zeros in sparse distributions * GPU implementations for sinkhorn and group lasso regularization @@ -686,7 +553,7 @@ This release contains a lot of contribution from new contributors. *7 Apr 2017* * New dimensionality reduction method (WDA) -* Efficient method emd2 returns only transport (in parallel if several histograms given) +* Efficient method emd2 returns only tarnsport (in paralell if several histograms given) @@ -727,4 +594,4 @@ It provides the following solvers: * Optimal transport for domain adaptation with group lasso regularization * Conditional gradient and Generalized conditional gradient for regularized OT. -Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder. \ No newline at end of file +Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder. diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 265006d2c..73667b324 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -19,7 +19,8 @@ from .utils import list_to_array, get_parameter_pair -def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, +def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', + reg_type="entropy", warmstart=None, numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs): r""" Solve the unbalanced entropic regularization optimal transport problem @@ -39,7 +40,7 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term, can be either KL divergence or negative entropy - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions - KL is the Kullback-Leibler divergence @@ -67,8 +68,17 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. If reg_m is an array, it must have the same backend as input arrays (a, b, M). method : str - method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or + method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or 'sinkhorn_reg_scaling', see those function for specific parameters + reg_type : string, optional + Regularizer term. Can take two values: + 'entropy' (negative entropy) + :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`, or + 'kl' (Kullback-Leibler) + :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors). numItermax : int, optional Max number of iterations stopThr : float, optional @@ -100,9 +110,8 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, >>> b=[.5, .5] >>> M=[[0., 1.], [1., 0.]] >>> ot.sinkhorn_unbalanced(a, b, M, 1, 1) - array([[0.51122823, 0.18807035], - [0.18807035, 0.51122823]]) - + array([[0.51122814, 0.18807032], + [0.18807032, 0.51122814]]) .. _references-sinkhorn-unbalanced: References @@ -134,21 +143,21 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, """ if method.lower() == 'sinkhorn': - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, + return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() in ['sinkhorn_reg_scaling']: - warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, + warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp') + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) else: @@ -156,8 +165,8 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', - numItermax=1000, stopThr=1e-6, verbose=False, - log=False, **kwargs): + reg_type="entropy", warmstart=None, numItermax=1000, + stopThr=1e-6, verbose=False, log=False, **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and return the loss @@ -175,7 +184,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term, can be either KL divergence or negative entropy - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions - KL is the Kullback-Leibler divergence @@ -203,8 +212,17 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. If reg_m is an array, it must have the same backend as input arrays (a, b, M). method : str - method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or - 'sinkhorn_reg_scaling', see those function for specific parameters + method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or + 'sinkhorn_reg_scaling', see those function for specific parameterss + reg_type : string, optional + Regularizer term. Can take two values: + 'entropy' (negative entropy) + :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`, or + 'kl' (Kullback-Leibler) + :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors). numItermax : int, optional Max number of iterations stopThr : float, optional @@ -226,12 +244,12 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', -------- >>> import ot + >>> import numpy as np >>> a=[.5, .10] >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] - >>> ot.unbalanced.sinkhorn_unbalanced2(a, b, M, 1., 1.) - array([0.31912866]) - + >>> np.round(ot.unbalanced.sinkhorn_unbalanced2(a, b, M, 1., 1.), 8) + 0.31912858 .. _references-sinkhorn-unbalanced2: References @@ -258,34 +276,60 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epsilon scaling :ref:`[9, 10] ` """ - b = list_to_array(b) + M, a, b = list_to_array(M, a, b) + nx = get_backend(M, a, b) + if len(b.shape) < 2: - b = b[:, None] + if method.lower() == 'sinkhorn': + res = sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + + elif method.lower() == 'sinkhorn_stabilized': + res = sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + elif method.lower() in ['sinkhorn_reg_scaling']: + warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp') + res = sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + else: + raise ValueError('Unknown method %s.' % method) - if method.lower() == 'sinkhorn': - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + if log: + return nx.sum(M * res[0]), res[1] + else: + return nx.sum(M * res) - elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, - stopThr=stopThr, - verbose=verbose, - log=log, **kwargs) - elif method.lower() in ['sinkhorn_reg_scaling']: - warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) else: - raise ValueError('Unknown method %s.' % method) + if method.lower() == 'sinkhorn': + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + + elif method.lower() == 'sinkhorn_stabilized': + return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + elif method.lower() in ['sinkhorn_reg_scaling']: + warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp') + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + else: + raise ValueError('Unknown method %s.' % method) -def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, - stopThr=1e-6, verbose=False, log=False, **kwargs): +def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", + warmstart=None, numItermax=1000, stopThr=1e-6, + verbose=False, log=False, **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and return the OT plan @@ -304,7 +348,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term, can be either KL divergence or negative entropy - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions - KL is the Kullback-Leibler divergence @@ -330,6 +374,15 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, For semi-relaxed case, use either `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. If reg_m is an array, it must have the same backend as input arrays (a, b, M). + reg_type : string, optional + Regularizer term. Can take two values: + 'entropy' (negative entropy) + :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`, or + 'kl' (Kullback-Leibler) + :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors). numItermax : int, optional Max number of iterations stopThr : float, optional @@ -361,9 +414,8 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] >>> ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.) - array([[0.51122823, 0.18807035], - [0.18807035, 0.51122823]]) - + array([[0.51122814, 0.18807032], + [0.18807032, 0.51122814]]) .. _references-sinkhorn-knopp-unbalanced: References @@ -404,15 +456,21 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, # we assume that no distances are null except those of the diagonal of # distances - if n_hists: - u = nx.ones((dim_a, 1), type_as=M) / dim_a - v = nx.ones((dim_b, n_hists), type_as=M) / dim_b - a = a.reshape(dim_a, 1) + if warmstart is None: + if n_hists: + u = nx.ones((dim_a, 1), type_as=M) + v = nx.ones((dim_b, n_hists), type_as=M) + a = a.reshape(dim_a, 1) + else: + u = nx.ones(dim_a, type_as=M) + v = nx.ones(dim_b, type_as=M) else: - u = nx.ones(dim_a, type_as=M) / dim_a - v = nx.ones(dim_b, type_as=M) / dim_b + u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) - K = nx.exp(M / (-reg)) + if reg_type == "kl": + K = nx.exp(-M / reg) * a.reshape(-1)[:, None] * b.reshape(-1)[None, :] + elif reg_type == "entropy": + K = nx.exp(-M / reg) fi_1 = reg_m1 / (reg_m1 + reg) if reg_m1 != float("inf") else 1 fi_2 = reg_m2 / (reg_m2 + reg) if reg_m2 != float("inf") else 1 @@ -474,9 +532,10 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, return u[:, None] * K * v[None, :] -def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000, - stopThr=1e-6, verbose=False, log=False, - **kwargs): +def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", + warmstart=None, tau=1e5, + numItermax=1000, stopThr=1e-6, + verbose=False, log=False, **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and return the loss @@ -496,7 +555,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term, can be either KL divergence or negative entropy - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions - KL is the Kullback-Leibler divergence @@ -523,6 +582,15 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 For semi-relaxed case, use either `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. If reg_m is an array, it must have the same backend as input arrays (a, b, M). + reg_type : string, optional + Regularizer term. Can take two values: + 'entropy' (negative entropy) + :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`, or + 'kl' (Kullback-Leibler) + :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors). tau : float threshold for max value in u or v for log scaling numItermax : int, optional @@ -555,9 +623,8 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] >>> ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, 1., 1.) - array([[0.51122823, 0.18807035], - [0.18807035, 0.51122823]]) - + array([[0.51122814, 0.18807032], + [0.18807032, 0.51122814]]) .. _references-sinkhorn-stabilized-unbalanced: References @@ -597,16 +664,24 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 # we assume that no distances are null except those of the diagonal of # distances - if n_hists: - u = nx.ones((dim_a, n_hists), type_as=M) / dim_a - v = nx.ones((dim_b, n_hists), type_as=M) / dim_b - a = a.reshape(dim_a, 1) + if warmstart is None: + if n_hists: + u = nx.ones((dim_a, n_hists), type_as=M) + v = nx.ones((dim_b, n_hists), type_as=M) + a = a.reshape(dim_a, 1) + else: + u = nx.ones(dim_a, type_as=M) + v = nx.ones(dim_b, type_as=M) else: - u = nx.ones(dim_a, type_as=M) / dim_a - v = nx.ones(dim_b, type_as=M) / dim_b + u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) - # print(reg) - K = nx.exp(-M / reg) + if reg_type == "kl": + log_ab = nx.log(a + 1e-16).reshape(-1)[:, None] + nx.log(b + 1e-16).reshape(-1)[None, :] + M0 = M - reg * log_ab + else: + M0 = M + + K = nx.exp(-M0 / reg) fi_1 = reg_m1 / (reg_m1 + reg) if reg_m1 != float("inf") else 1 fi_2 = reg_m2 / (reg_m2 + reg) if reg_m2 != float("inf") else 1 @@ -641,7 +716,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 else: alpha = alpha + reg * nx.log(nx.max(u)) beta = beta + reg * nx.log(nx.max(v)) - K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg) + K = nx.exp((alpha[:, None] + beta[None, :] - M0) / reg) v = nx.ones(v.shape, type_as=v) Kv = nx.dot(K, v) @@ -687,7 +762,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 nx.log(M + 1e-100)[:, :, None] + logu[:, None, :] + logv[None, :, :] - - M[:, :, None] / reg, + - M0[:, :, None] / reg, axis=(0, 1) ) res = nx.exp(res) @@ -697,7 +772,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 return res else: # return OT matrix - ot_matrix = nx.exp(logu[:, None] + logv[None, :] - M / reg) + ot_matrix = nx.exp(logu[:, None] + logv[None, :] - M0 / reg) if log: return ot_matrix, log else: @@ -1074,7 +1149,7 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, raise ValueError("Unknown method '%s'." % method) -def mm_unbalanced(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, +def mm_unbalanced(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, numItermax=1000, stopThr=1e-15, verbose=False, log=False): r""" Solve the unbalanced optimal transport problem and return the OT plan. @@ -1084,7 +1159,7 @@ def mm_unbalanced(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg_{m1}} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + - \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{a} \mathbf{b}^T) + \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{c}) s.t. \gamma \geq 0 @@ -1094,6 +1169,7 @@ def mm_unbalanced(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions + - :math:`\mathbf{c}` is a reference distribution for the regularization - div is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence The algorithm used for solving the problem is a maximization- @@ -1113,8 +1189,11 @@ def mm_unbalanced(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, then the same reg_m is applied to both marginal relaxations. If reg_m is an array, it must have the same backend as input arrays (a, b, M). reg : float, optional (default = 0) - Entropy regularization term >= 0. + Regularization term >= 0. By default, solve the unregularized problem + c : array-like (dim_a, dim_b), optional (default = None) + Reference measure for the regularization. + If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`. div: string, optional Divergence to quantify the difference between the marginals. Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) @@ -1172,36 +1251,33 @@ def mm_unbalanced(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, if len(b) == 0: b = nx.ones(dim_b, type_as=M) / dim_b - if G0 is None: - G = a[:, None] * b[None, :] - else: - G = G0 + G = a[:, None] * b[None, :] if G0 is None else G0 + c = a[:, None] * b[None, :] if c is None else c reg_m1, reg_m2 = get_parameter_pair(reg_m) if log: log = {'err': [], 'G': []} - if div == 'kl': - sum_r = reg + reg_m1 + reg_m2 - r1, r2, r = reg_m1 / sum_r, reg_m2 / sum_r, reg / sum_r - K = a[:, None]**(r1 + r) * b[None, :]**(r2 + r) * nx.exp(- M / sum_r) - elif div == 'l2': - K = reg_m1 * a[:, None] + reg_m2 * b[None, :] + reg * a[:, None] * b[None, :] - M - K = nx.maximum(K, nx.zeros((dim_a, dim_b), type_as=M)) - else: + if div not in ["kl", "l2"]: warnings.warn("The div parameter should be either equal to 'kl' or \ 'l2': it has been set to 'kl'.") div = 'kl' + + if div == 'kl': sum_r = reg + reg_m1 + reg_m2 r1, r2, r = reg_m1 / sum_r, reg_m2 / sum_r, reg / sum_r - K = a[:, None]**(r1 + r) * b[None, :]**(r2 + r) * nx.exp(- M / sum_r) + K = (a[:, None]**r1) * (b[None, :]**r2) * (c**r) * nx.exp(- M / sum_r) + elif div == 'l2': + K = reg_m1 * a[:, None] + reg_m2 * b[None, :] + reg * c - M + K = nx.maximum(K, nx.zeros((dim_a, dim_b), type_as=M)) for i in range(numItermax): Gprev = G if div == 'kl': - G = K * G**(r1 + r2) / (nx.sum(G, 1, keepdims=True)**r1 * nx.sum(G, 0, keepdims=True)**r2 + 1e-16) + Gd = (nx.sum(G, 1, keepdims=True)**r1) * (nx.sum(G, 0, keepdims=True)**r2) + 1e-16 + G = K * G**(r1 + r2) / Gd elif div == 'l2': Gd = reg_m1 * nx.sum(G, 1, keepdims=True) + \ reg_m2 * nx.sum(G, 0, keepdims=True) + reg * G + 1e-16 @@ -1223,7 +1299,7 @@ def mm_unbalanced(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, return G -def mm_unbalanced2(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, +def mm_unbalanced2(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, numItermax=1000, stopThr=1e-15, verbose=False, log=False): r""" Solve the unbalanced optimal transport problem and return the OT plan. @@ -1233,7 +1309,7 @@ def mm_unbalanced2(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg_{m1}} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + - \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{a} \mathbf{b}^T) + \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{c}) s.t. \gamma \geq 0 @@ -1243,6 +1319,7 @@ def mm_unbalanced2(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions + - :math:`\mathbf{c}` is a reference distribution for the regularization - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence The algorithm used for solving the problem is a maximization- @@ -1264,6 +1341,9 @@ def mm_unbalanced2(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, reg : float, optional (default = 0) Entropy regularization term >= 0. By default, solve the unregularized problem + c : array-like (dim_a, dim_b), optional (default = None) + Reference measure for the regularization. + If None, then use `\mathbf{c} = mathbf{a} mathbf{b}^T`. div: string, optional Divergence to quantify the difference between the marginals. Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) @@ -1307,7 +1387,7 @@ def mm_unbalanced2(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, ot.lp.emd2 : Unregularized OT loss ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss """ - _, log_mm = mm_unbalanced(a, b, M, reg_m, reg=reg, div=div, G0=G0, + _, log_mm = mm_unbalanced(a, b, M, reg_m, c=c, reg=reg, div=div, G0=G0, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=True) @@ -1317,7 +1397,7 @@ def mm_unbalanced2(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, return log_mm['cost'] -def _get_loss_unbalanced(a, b, M, reg, reg_m1, reg_m2, reg_div='kl', regm_div='kl'): +def _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div='kl', regm_div='kl'): """ return the loss function (scipy.optimize compatible) for regularized unbalanced OT @@ -1326,25 +1406,25 @@ def _get_loss_unbalanced(a, b, M, reg, reg_m1, reg_m2, reg_div='kl', regm_div='k m, n = M.shape def kl(p, q): - return np.sum(p * np.log(p / q + 1e-16)) - p.sum() + q.sum() + return np.sum(p * np.log(p / q + 1e-16)) - np.sum(p) + np.sum(q) def reg_l2(G): - return np.sum((G - a[:, None] * b[None, :])**2) / 2 + return np.sum((G - c)**2) / 2 def grad_l2(G): - return G - a[:, None] * b[None, :] + return G - c def reg_kl(G): - return kl(G, a[:, None] * b[None, :]) + return kl(G, c) def grad_kl(G): - return np.log(G / (a[:, None] * b[None, :]) + 1e-16) + return np.log(G / c + 1e-16) def reg_entropy(G): - return np.sum(G * np.log(G + 1e-16)) + return np.sum(G * np.log(G + 1e-16)) - np.sum(G) def grad_entropy(G): - return np.log(G + 1e-16) + 1 + return np.log(G + 1e-16) if reg_div == 'kl': reg_fun = reg_kl @@ -1392,7 +1472,7 @@ def _func(G): return _func -def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, numItermax=1000, +def lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', G0=None, numItermax=1000, stopThr=1e-15, method='L-BFGS-B', verbose=False, log=False): r""" Solve the unbalanced optimal transport problem and return the OT plan using L-BFGS-B. @@ -1400,7 +1480,7 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, .. math:: W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + - + \mathrm{reg} \mathrm{div}(\gamma, \mathbf{a} \mathbf{b}^T) + + \mathrm{reg} \mathrm{div}(\gamma, \mathbf{c}) \mathrm{reg_{m1}} \cdot \mathrm{div_m}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) @@ -1412,6 +1492,7 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions + - :math:`\mathbf{c}` is a reference distribution for the regularization - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence The algorithm used for solving the problem is a L-BFGS-B from scipy.optimize @@ -1426,6 +1507,9 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, loss matrix reg: float regularization term >=0 + c : array-like (dim_a, dim_b), optional (default = None) + Reference measure for the regularization. + If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`. reg_m: float or indexable object of length 1 or 2 Marginal relaxation term >= 0, but cannot be infinity. If reg_m is a scalar or an indexable object of length 1, @@ -1433,7 +1517,8 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, If reg_m is an array, it must be a Numpy array. reg_div: string, optional Divergence used for regularization. - Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) + Can take three values: 'entropy' (negative entropy), or + 'kl' (Kullback-Leibler) or 'l2' (quadratic). regm_div: string, optional Divergence to quantify the difference between the marginals. Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) @@ -1482,19 +1567,15 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, M, a, b = list_to_array(M, a, b) nx = get_backend(M, a, b) - M0 = M + # convert to numpy a, b, M = nx.to_numpy(a, b, M) + G0 = np.zeros(M.shape) if G0 is None else nx.to_numpy(G0) + c = a[:, None] * b[None, :] if c is None else nx.to_numpy(c) reg_m1, reg_m2 = get_parameter_pair(reg_m) - - if G0 is not None: - G0 = nx.to_numpy(G0) - else: - G0 = np.zeros(M.shape) - - _func = _get_loss_unbalanced(a, b, M, reg, reg_m1, reg_m2, reg_div, regm_div) + _func = _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div, regm_div) res = minimize(_func, G0.ravel(), method=method, jac=True, bounds=Bounds(0, np.inf), tol=stopThr, options=dict(maxiter=numItermax, disp=verbose)) diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index 272794cb8..7007e336b 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -14,8 +14,8 @@ from ot.unbalanced import barycenter_unbalanced -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) -def test_unbalanced_convergence(nx, method): +@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], ["kl", "entropy"])) +def test_unbalanced_convergence(nx, method, reg_type): # test generalized sinkhorn for unbalanced OT n = 100 rng = np.random.RandomState(42) @@ -25,29 +25,32 @@ def test_unbalanced_convergence(nx, method): # make dists unbalanced b = ot.utils.unif(n) * 1.5 - M = ot.dist(x, x) + a, b, M = nx.from_numpy(a, b, M) + epsilon = 1. reg_m = 1. - a, b, M = nx.from_numpy(a, b, M) - - G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, - reg_m=reg_m, - method=method, - log=True, - verbose=True) + G, log = ot.unbalanced.sinkhorn_unbalanced( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, log=True, verbose=True + ) loss = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2( - a, b, M, epsilon, reg_m, method=method, verbose=True + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, verbose=True )) # check fixed point equations # in log-domain fi = reg_m / (reg_m + epsilon) logb = nx.log(b + 1e-16) loga = nx.log(a + 1e-16) - logKtu = nx.logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1) - logKv = nx.logsumexp(log["logv"][None, :] - M / epsilon, axis=1) - + if reg_type == "entropy": + logKtu = nx.logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1) + logKv = nx.logsumexp(log["logv"][None, :] - M / epsilon, axis=1) + elif reg_type == "kl": + log_ab = loga[:, None] + logb[None, :] + logKtu = nx.logsumexp(log["logu"][None, :] - M.T / epsilon + log_ab.T, axis=1) + logKv = nx.logsumexp(log["logv"][None, :] - M / epsilon + log_ab, axis=1) v_final = fi * (logb - logKtu) u_final = fi * (loga - logKv) @@ -65,15 +68,109 @@ def test_unbalanced_convergence(nx, method): a, b = nx.from_numpy(a_np, b_np) G = ot.unbalanced.sinkhorn_unbalanced( - a, b, M, reg=epsilon, reg_m=reg_m, method=method, verbose=True + a, b, M, reg=epsilon, reg_m=reg_m, + method=method, reg_type=reg_type, verbose=True ) G_np = ot.unbalanced.sinkhorn_unbalanced( - a_np, b_np, M_np, reg=epsilon, reg_m=reg_m, method=method, verbose=True + a_np, b_np, M_np, reg=epsilon, reg_m=reg_m, + method=method, reg_type=reg_type, verbose=True ) np.testing.assert_allclose(G_np, nx.to_numpy(G)) -@pytest.mark.parametrize("method,reg_m", itertools.product(["sinkhorn", "sinkhorn_stabilized"], [1, float("inf")])) +@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], ["kl", "entropy"])) +def test_unbalanced_warmstart(nx, method, reg_type): + # test generalized sinkhorn for unbalanced OT + n = 100 + rng = np.random.RandomState(42) + + x = rng.randn(n, 2) + a = ot.utils.unif(n) + b = ot.utils.unif(n) + M = ot.dist(x, x) + a, b, M = nx.from_numpy(a, b, M) + + epsilon = 1. + reg_m = 1. + + G0, log0 = ot.unbalanced.sinkhorn_unbalanced( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, warmstart=None, log=True, verbose=True + ) + loss0 = ot.unbalanced.sinkhorn_unbalanced2( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, warmstart=None, verbose=True + ) + + dim_a, dim_b = M.shape + warmstart = (nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M)) + G, log = ot.unbalanced.sinkhorn_unbalanced( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, warmstart=warmstart, log=True, verbose=True + ) + loss = ot.unbalanced.sinkhorn_unbalanced2( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, warmstart=warmstart, verbose=True + ) + + _, log_emd = ot.lp.emd(a, b, M, log=True) + warmstart1 = (log_emd["u"], log_emd["v"]) + G1, log1 = ot.unbalanced.sinkhorn_unbalanced( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, warmstart=warmstart1, log=True, verbose=True + ) + loss1 = ot.unbalanced.sinkhorn_unbalanced2( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, warmstart=warmstart1, verbose=True + ) + + np.testing.assert_allclose( + nx.to_numpy(log["logu"]), nx.to_numpy(log0["logu"]), atol=1e-05) + np.testing.assert_allclose( + nx.to_numpy(log["logv"]), nx.to_numpy(log0["logv"]), atol=1e-05) + np.testing.assert_allclose( + nx.to_numpy(log0["logu"]), nx.to_numpy(log1["logu"]), atol=1e-05) + np.testing.assert_allclose( + nx.to_numpy(log0["logv"]), nx.to_numpy(log1["logv"]), atol=1e-05) + + np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-05) + np.testing.assert_allclose(nx.to_numpy(G0), nx.to_numpy(G1), atol=1e-05) + + np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-5) + np.testing.assert_allclose(nx.to_numpy(loss0), nx.to_numpy(loss1), atol=1e-5) + + +@pytest.mark.parametrize("method,reg_type, log", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], ["kl", "entropy"], [True, False])) +def test_sinkhorn_unbalanced2(nx, method, reg_type, log): + n = 100 + rng = np.random.RandomState(42) + + x = rng.randn(n, 2) + a = ot.utils.unif(n) + + # make dists unbalanced + b = ot.utils.unif(n) * 1.5 + M = ot.dist(x, x) + a, b, M = nx.from_numpy(a, b, M) + + epsilon = 1. + reg_m = 1. + + loss = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, log=False, verbose=True + )) + + res = ot.unbalanced.sinkhorn_unbalanced2( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, log=log, verbose=True + ) + loss0 = res[0] if log else res + + np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-5) + + +@pytest.mark.parametrize("method,reg_m", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], [1, float("inf")])) def test_unbalanced_relaxation_parameters(nx, method, reg_m): # test generalized sinkhorn for unbalanced OT n = 100 @@ -117,7 +214,7 @@ def test_unbalanced_relaxation_parameters(nx, method, reg_m): nx.to_numpy(loss), nx.to_numpy(loss_opt), atol=1e-05) -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"]) def test_unbalanced_multiple_inputs(nx, method): # test generalized sinkhorn for unbalanced OT n = 100 @@ -135,11 +232,10 @@ def test_unbalanced_multiple_inputs(nx, method): a, b, M = nx.from_numpy(a, b, M) - loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, - reg_m=reg_m, - method=method, - log=True, - verbose=True) + G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, + reg_m=reg_m, method=method, + log=True, verbose=True) + # check fixed point equations # in log-domain fi = reg_m / (reg_m + epsilon) @@ -394,6 +490,31 @@ def test_lbfgsb_unbalanced_relaxation_parameters(nx, reg_div, regm_div): np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-06) +@pytest.mark.parametrize("reg_div,regm_div", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2'])) +def test_lbfgsb_reference_measure(nx, reg_div, regm_div): + + np.random.seed(42) + + xs = np.random.randn(5, 2) + xt = np.random.randn(6, 2) + M = ot.dist(xs, xt) + a = ot.unif(5) + b = ot.unif(6) + + a, b, M = nx.from_numpy(a, b, M) + c = a[:, None] * b[None, :] + + G, _ = ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=1, reg_m=10, c=None, + reg_div=reg_div, regm_div=regm_div, + log=True, verbose=False) + + G0, _ = ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=1, reg_m=10, c=c, + reg_div=reg_div, regm_div=regm_div, + log=True, verbose=False) + + np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-06) + + @pytest.mark.parametrize("div", ["kl", "l2"]) def test_mm_convergence(nx, div): n = 100 @@ -483,6 +604,41 @@ def test_mm_relaxation_parameters(nx, div): np.testing.assert_allclose(loss_0, loss_1, atol=1e-5) +@pytest.mark.parametrize("div", ["kl", "l2"]) +def test_mm_reference_measure(nx, div): + n = 100 + rng = np.random.RandomState(42) + x = rng.randn(n, 2) + rng = np.random.RandomState(75) + y = rng.randn(n, 2) + a_np = ot.utils.unif(n) + b_np = ot.utils.unif(n) + + M = ot.dist(x, y) + M = M / M.max() + a, b, M = nx.from_numpy(a_np, b_np, M) + c = a[:, None] * b[None, :] + + reg = 1e-2 + reg_m = 100 + + G0, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, c=None, reg=reg, + div=div, verbose=False, log=True) + loss_0 = ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, c=None, reg=reg, + div=div, verbose=True) + loss_0 = nx.to_numpy(loss_0) + + G1, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, c=c, + reg=reg, div=div, + verbose=False, log=True) + loss_1 = ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, c=c, + reg=reg, div=div, verbose=True) + loss_1 = nx.to_numpy(loss_1) + + np.testing.assert_allclose(nx.to_numpy(G0), nx.to_numpy(G1), atol=1e-05) + np.testing.assert_allclose(loss_0, loss_1, atol=1e-5) + + def test_mm_wrong_divergence(nx): n = 100