Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] New API ot.solve_sample #563

Merged
merged 32 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
f49f6b4
new file for lr sinkhorn
laudavid Oct 24, 2023
3c4b50f
lr sinkhorn, solve_sample, OTResultLazy
laudavid Oct 24, 2023
3034e57
add test functions + small modif lr_sin/solve_sample
laudavid Oct 25, 2023
085863a
add import to __init__
laudavid Oct 26, 2023
9becafc
modify low rank, remove solve_sample,OTResultLazy
laudavid Nov 3, 2023
855234d
pull from master
laudavid Nov 3, 2023
58576a3
solve_sample + test functions
laudavid Nov 3, 2023
ed1b22d
remove low rank from branch
laudavid Nov 3, 2023
6ea251c
new file for lr sinkhorn
laudavid Oct 24, 2023
965e4d6
lr sinkhorn, solve_sample, OTResultLazy
laudavid Oct 24, 2023
fd5e26d
add test functions + small modif lr_sin/solve_sample
laudavid Oct 25, 2023
3df3b77
add import to __init__
laudavid Oct 26, 2023
c7e899f
remove lowrank + OTResultLazy
laudavid Nov 3, 2023
9403851
clean ot.solve_sample and remve lazy test cause not ilplemented yet
rflamary Nov 7, 2023
f8a7c0b
Merge branch 'master' into newapi_solve_sample
rflamary Nov 7, 2023
411f302
add factored and gaussian solvers
rflamary Nov 7, 2023
a1fbce3
Merge branch 'master' into newapi_solve_sample
rflamary Nov 8, 2023
5812de2
workin lazy sinkhorn with lazy tensor returned
rflamary Nov 9, 2023
e7f5bee
Merge branch 'newapi_solve_sample' of github.com:PythonOT/POT into ne…
rflamary Nov 9, 2023
65b5643
Merge branch 'master' into newapi_solve_sample
rflamary Nov 9, 2023
e2e76cf
Merge branch 'master' into newapi_solve_sample
rflamary Nov 10, 2023
26c6f49
stuff
rflamary Nov 14, 2023
c9a9035
Merge branch 'newapi_solve_sample' of github.com:PythonOT/POT into ne…
rflamary Nov 14, 2023
46bf259
Merge branch 'master' into newapi_solve_sample
rflamary Nov 14, 2023
c0884d8
merge master
rflamary Nov 14, 2023
4b88982
update documùentation
rflamary Nov 14, 2023
d4e60b0
beter documentation
rflamary Nov 14, 2023
b1182c8
pep8
rflamary Nov 14, 2023
f54ec56
big update tests
rflamary Nov 15, 2023
129d905
debug small test
rflamary Nov 15, 2023
c098a15
remarques cédri
rflamary Nov 15, 2023
d3f5bf3
small stuff
rflamary Nov 15, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
+ Add exact line-search for `gromov_wasserstein` and `fused_gromov_wasserstein` with KL loss (PR #556)
+ Add KL loss to all semi-relaxed (Fused) Gromov-Wasserstein solvers (PR #559)
+ Further upgraded unbalanced OT solvers for more flexibility and future use (PR #551)
+ New API function `ot.solve_sample` for solving OT problems from empirical samples (PR #563)

#### Closed issues
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
Expand Down
5 changes: 3 additions & 2 deletions ot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from . import solvers
from . import gaussian


# OT functions
from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d,
binary_search_circle, wasserstein_circle,
Expand All @@ -50,7 +51,7 @@
gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
from .weak import weak_optimal_transport
from .factored import factored_optimal_transport
from .solvers import solve, solve_gromov
from .solvers import solve, solve_gromov, solve_sample

# utils functions
from .utils import dist, unif, tic, toc, toq
Expand All @@ -65,7 +66,7 @@
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'sliced_wasserstein_sphere',
'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein',
'fused_gromov_wasserstein2', 'max_sliced_wasserstein_distance', 'weak_optimal_transport',
'factored_optimal_transport', 'solve', 'solve_gromov',
'factored_optimal_transport', 'solve', 'solve_gromov','solve_sample',
'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers',
'binary_search_circle', 'wasserstein_circle',
'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif']
48 changes: 47 additions & 1 deletion ot/bregman/_empirical.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,56 @@

import warnings

from ..utils import dist, list_to_array, unif
from ..utils import dist, list_to_array, unif, LazyTensor
from ..backend import get_backend

from ._sinkhorn import sinkhorn, sinkhorn2


def get_sinkhorn_lazytensor(X_a, X_b, f, g, metric='sqeuclidean', reg=1e-1, nx=None):
r""" Get a LazyTensor of Sinkhorn solution from the dual potentials

The returned LazyTensor is
:math:`\mathbf{T} = exp( \mathbf{f} \mathbf{1}_b^\top + \mathbf{1}_a \mathbf{g}^\top - \mathbf{C}/reg)`, where :math:`\mathbf{C}` is the pairwise metric matrix between samples :math:`\mathbf{X}_a` and :math:`\mathbf{X}_b`.

Parameters
----------
X_a : array-like, shape (n_samples_a, dim)
samples in the source domain
X_b : array-like, shape (n_samples_b, dim)
samples in the target domain
f : array-like, shape (n_samples_a,)
First dual potentials (log space)
g : array-like, shape (n_samples_b,)
Second dual potentials (log space)
metric : str, default='sqeuclidean'
Metric used for the cost matrix computation
reg : float, default=1e-1
Regularization term >0
nx : Backend(), default=None
Numerical backend used


Returns
-------
T : LazyTensor
Sinkhorn solution tensor
"""

if nx is None:
nx = get_backend(X_a, X_b, f, g)

shape = (X_a.shape[0], X_b.shape[0])

def func(i, j, X_a, X_b, f, g, metric, reg):
C = dist(X_a[i], X_b[j], metric=metric)
return nx.exp(f[i, None] + g[None, j] - C / reg)

T = LazyTensor(shape, func, X_a=X_a, X_b=X_b, f=f, g=g, metric=metric, reg=reg)

return T


def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False,
log=False, warn=True, warmstart=None, **kwargs):
Expand Down Expand Up @@ -198,6 +242,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
if log:
dict_log["u"] = f
rflamary marked this conversation as resolved.
Show resolved Hide resolved
dict_log["v"] = g
dict_log["niter"] = i_ot
dict_log["lazy_plan"] = get_sinkhorn_lazytensor(X_s, X_t, f, g, metric, reg)
return (f, g, dict_log)
else:
return (f, g)
Expand Down
2 changes: 1 addition & 1 deletion ot/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
Cs12 = nx.sqrtm(Cs)

B = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12)))
W = nx.sqrt(nx.norm(ms - mt)**2 + B)
W = nx.sqrt(nx.maximum(nx.norm(ms - mt)**2 + B, 0))

if log:
log = {}
Expand Down
Loading
Loading