diff --git a/RELEASES.md b/RELEASES.md index 951b5f327..30c248044 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -8,6 +8,7 @@ + Callbacks for generalized conditional gradient in `ot.da.sinkhorn_l1l2_gl` are now vectorized to improve performance (PR #507) + The `linspace` method of the backends now has the `type_as` argument to convert to the same dtype and device. (PR #533) + The `convolutional_barycenter2d` and `convolutional_barycenter2d_debiased` functions now work with different devices.. (PR #533) ++ New API for Gromov-Wasserstein solvers with `ot.solve_gromov` function (PR #536) #### Closed issues - Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504) diff --git a/ot/__init__.py b/ot/__init__.py index 44e87eabe..f16b6fcfc 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -50,7 +50,7 @@ gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) from .weak import weak_optimal_transport from .factored import factored_optimal_transport -from .solvers import solve +from .solvers import solve, solve_gromov # utils functions from .utils import dist, unif, tic, toc, toq @@ -65,7 +65,7 @@ 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'sliced_wasserstein_sphere', 'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', 'max_sliced_wasserstein_distance', 'weak_optimal_transport', - 'factored_optimal_transport', 'solve', + 'factored_optimal_transport', 'solve', 'solve_gromov', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers', 'binary_search_circle', 'wasserstein_circle', 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif'] diff --git a/ot/da.py b/ot/da.py index 8764268f0..3628db51e 100644 --- a/ot/da.py +++ b/ot/da.py @@ -2274,6 +2274,7 @@ class NearestBrenierPotential(BaseTransport): ot.mapping.nearest_brenier_potential_fit : Fitting the SSNB on source and target data ot.mapping.nearest_brenier_potential_predict_bounds : Predicting SSNB images on new source data """ + def __init__(self, strongly_convex_constant=0.6, gradient_lipschitz_constant=1.4, log=False, its=100, seed=None): self.strongly_convex_constant = strongly_convex_constant self.gradient_lipschitz_constant = gradient_lipschitz_constant diff --git a/ot/gromov/_bregman.py b/ot/gromov/_bregman.py index 6dc705949..146e82631 100644 --- a/ot/gromov/_bregman.py +++ b/ot/gromov/_bregman.py @@ -330,6 +330,7 @@ def entropic_gromov_wasserstein2( learning for graph matching and node embedding. In International Conference on Machine Learning (ICML), 2019. """ + T, logv = entropic_gromov_wasserstein( C1, C2, p, q, loss_fun, epsilon, symmetric, G0, max_iter, tol, solver, warmstart, verbose, log=True, **kwargs) @@ -815,12 +816,19 @@ def entropic_fused_gromov_wasserstein2( (ICML). 2019. """ + + nx = get_backend(M, C1, C2) + T, logv = entropic_fused_gromov_wasserstein( M, C1, C2, p, q, loss_fun, epsilon, symmetric, alpha, G0, max_iter, tol, solver, warmstart, verbose, log=True, **kwargs) logv['T'] = T + lin_term = nx.sum(T * M) + logv['quad_loss'] = (logv['fgw_dist'] - (1 - alpha) * lin_term) + logv['lin_loss'] = lin_term * (1 - alpha) + if log: return logv['fgw_dist'], logv else: diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py index 20373f33b..d5e4c7f13 100644 --- a/ot/gromov/_gw.py +++ b/ot/gromov/_gw.py @@ -582,6 +582,12 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', fgw_dist = log_fgw['fgw_dist'] log_fgw['T'] = T + # compute separate terms for gradients and log + lin_term = nx.sum(T * M) + log_fgw['quad_loss'] = (fgw_dist - (1 - alpha) * lin_term) + log_fgw['lin_loss'] = lin_term * (1 - alpha) + gw_term = log_fgw['quad_loss'] / alpha + if loss_fun == 'square_loss': gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T)) gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T)) @@ -591,8 +597,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', log_fgw['v'] - nx.mean(log_fgw['v']), alpha * gC1, alpha * gC2, (1 - alpha) * T)) else: - lin_term = nx.sum(T * M) - gw_term = (fgw_dist - (1 - alpha) * lin_term) / alpha + fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M, alpha), (log_fgw['u'] - nx.mean(log_fgw['u']), log_fgw['v'] - nx.mean(log_fgw['v']), diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py index b36a81c75..0b905c1fa 100644 --- a/ot/gromov/_semirelaxed.py +++ b/ot/gromov/_semirelaxed.py @@ -488,6 +488,8 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo q = nx.sum(T, 0) srfgw_dist = log_fgw['srfgw_dist'] log_fgw['T'] = T + log_fgw['lin_loss'] = nx.sum(M * T) * (1 - alpha) + log_fgw['quad_loss'] = srfgw_dist - log_fgw['lin_loss'] if loss_fun == 'square_loss': gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T)) @@ -979,7 +981,9 @@ def df(G): if log: qG = nx.sum(G, 0) marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t)) - log['srfgw_dist'] = alpha * gwloss(constC + marginal_product, hC1, hC2, G, nx) + (1 - alpha) * nx.sum(M * G) + log['lin_loss'] = nx.sum(M * G) * (1 - alpha) + log['quad_loss'] = alpha * gwloss(constC + marginal_product, hC1, hC2, G, nx) + log['srfgw_dist'] = log['lin_loss'] + log['quad_loss'] return G, log else: return G diff --git a/ot/solvers.py b/ot/solvers.py index bba2734e5..0313cf588 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -14,6 +14,14 @@ from .bregman import sinkhorn_log from .partial import partial_wasserstein_lagrange from .smooth import smooth_ot_dual +from .gromov import (gromov_wasserstein2, fused_gromov_wasserstein2, + entropic_gromov_wasserstein2, entropic_fused_gromov_wasserstein2, + semirelaxed_gromov_wasserstein2, semirelaxed_fused_gromov_wasserstein2, + entropic_semirelaxed_fused_gromov_wasserstein2, + entropic_semirelaxed_gromov_wasserstein2) +from .partial import partial_gromov_wasserstein2, entropic_partial_gromov_wasserstein2 + +#, entropic_gromov_wasserstein2, entropic_fused_gromov_wasserstein2 def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, @@ -51,11 +59,11 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, Unbalanced penalization weight :math:`\lambda_u`, by default None (balanced OT) unbalanced_type : str, optional - Type of unbalanced penalization unction :math:`U` either "KL", "L2", "TV", by default "KL" + Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL" n_threads : int, optional Number of OMP threads for exact OT solver, by default 1 max_iter : int, optional - Maximum number of iteration, by default None (default values in each solvers) + Maximum number of iterations, by default None (default values in each solvers) plan_init : array_like, shape (dim_a, dim_b), optional Initialization of the OT plan for iterative methods, by default None potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional @@ -345,3 +353,498 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, value_linear=value_linear, plan=plan, status=status, backend=nx) return res + + +def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, + alpha=0.5, reg=None, + reg_type="entropy", unbalanced=None, unbalanced_type='KL', + n_threads=1, method=None, max_iter=None, plan_init=None, tol=None, + verbose=False): + r""" Solve the discrete (Fused) Gromov-Wasserstein and return :any:`OTResult` object + + The function solves the following optimization problem: + + .. math:: + \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + + The regularization is selected with `reg` (:math:`\lambda_r`) and + `reg_type`. By default ``reg=None`` and there is no regularization. The + unbalanced marginal penalization can be selected with `unbalanced` + (:math:`\lambda_u`) and `unbalanced_type`. By default ``unbalanced=None`` + and the function solves the exact optimal transport problem (respecting the + marginals). + + Parameters + ---------- + Ca : array_like, shape (dim_a, dim_a) + Cost matrix in the source domain + Cb : array_like, shape (dim_b, dim_b) + Cost matrix in the target domain + M : array_like, shape (dim_a, dim_b), optional + Linear cost matrix for Fused Gromov-Wasserstein (default is None). + a : array-like, shape (dim_a,), optional + Samples weights in the source domain (default is uniform) + b : array-like, shape (dim_b,), optional + Samples weights in the source domain (default is uniform) + loss : str, optional + Type of loss function, either ``"L2"`` or ``"KL"``, by default ``"L2"`` + symmetric : bool, optional + Use symmetric version of the Gromov-Wasserstein problem, by default None + tests whether the matrices are symmetric or True/False to avoid the test. + reg : float, optional + Regularization weight :math:`\lambda_r`, by default None (no reg., exact + OT) + reg_type : str, optional + Type of regularization :math:`R`, by default "entropy" (only used when + ``reg!=None``) + alpha : float, optional + Weight the quadratic term (alpha*Gromov) and the linear term + ((1-alpha)*Wass) in the Fused Gromov-Wasserstein problem. Not used for + Gromov problem (when M is not provided). By default ``alpha=None`` + corresponds to ``alpha=1`` for Gromov problem (``M==None``) and + ``alpha=0.5`` for Fused Gromov-Wasserstein problem (``M!=None``) + unbalanced : float, optional + Unbalanced penalization weight :math:`\lambda_u`, by default None + (balanced OT), Not implemented yet + unbalanced_type : str, optional + Type of unbalanced penalization function :math:`U` either "KL", "semirelaxed", + "partial", by default "KL" but note that it is not implemented yet. + n_threads : int, optional + Number of OMP threads for exact OT solver, by default 1 + method : str, optional + Method for solving the problem, for entropic problems "PGD" is projected + gradient descent and "PPA" for proximal point, default None for + automatic selection ("PGD"). + max_iter : int, optional + Maximum number of iterations, by default None (default values in each + solvers) + plan_init : array_like, shape (dim_a, dim_b), optional + Initialization of the OT plan for iterative methods, by default None + tol : float, optional + Tolerance for solution precision, by default None (default values in + each solvers) + verbose : bool, optional + Print information in the solver, by default False + + Returns + ------- + res : OTResult() + Result of the optimization problem. The information can be obtained as follows: + + - res.plan : OT plan :math:`\mathbf{T}` + - res.potentials : OT dual potentials + - res.value : Optimal value of the optimization problem + - res.value_linear : Linear OT loss with the optimal OT plan + - res.value_quad : Quadratic (GW) part of the OT loss with the optimal OT plan + + See :any:`OTResult` for more information. + + Notes + ----- + The following methods are available for solving the Gromov-Wasserstein + problem: + + - **Classical Gromov-Wasserstein (GW) problem [3]** (default parameters): + + .. math:: + \min_{\mathbf{T}\geq 0} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0 + + can be solved with the following code: + + .. code-block:: python + + res = ot.solve_gromov(Ca, Cb) # uniform weights + res = ot.solve_gromov(Ca, Cb, a=a, b=b) # given weights + res = ot.solve_gromov(Ca, Cb, loss='KL') # KL loss + + plan = res.plan # GW plan + value = res.value # GW value + + - **Fused Gromov-Wasserstein (FGW) problem [24]** (when ``M!=None``): + + .. math:: + \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0 + + can be solved with the following code: + + .. code-block:: python + + res = ot.solve_gromov(Ca, Cb, M) # uniform weights, alpha=0.5 (default) + res = ot.solve_gromov(Ca, Cb, M, a=a, b=b, alpha=0.1) # given weights and alpha + + plan = res.plan # FGW plan + loss_linear_term = res.value_linear # Wasserstein part of the loss + loss_quad_term = res.value_quad # Gromov part of the loss + loss = res.value # FGW value + + - **Regularized (Fused) Gromov-Wasserstein (GW) problem [12]** (when ``reg!=None``): + + .. math:: + \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} + \lambda_r R(\mathbf{T}) + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0 + + can be solved with the following code: + + .. code-block:: python + + res = ot.solve_gromov(Ca, Cb, reg=1.0) # GW entropy regularization (default) + res = ot.solve_gromov(Ca, Cb, M, a=a, b=b, reg=10, alpha=0.1) # FGW with entropy + + plan = res.plan # FGW plan + loss_linear_term = res.value_linear # Wasserstein part of the loss + loss_quad_term = res.value_quad # Gromov part of the loss + loss = res.value # FGW value (including regularization) + + - **Semi-relaxed (Fused) Gromov-Wasserstein (GW) [48]** (when ``unbalanced='semirelaxed'``): + + .. math:: + \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T} \geq 0 + + can be solved with the following code: + + .. code-block:: python + + res = ot.solve_gromov(Ca, Cb, unbalanced='semirelaxed') # semirelaxed GW + res = ot.solve_gromov(Ca, Cb, unbalanced='semirelaxed', reg=1) # entropic semirelaxed GW + res = ot.solve_gromov(Ca, Cb, M, unbalanced='semirelaxed', alpha=0.1) # semirelaxed FGW + + plan = res.plan # FGW plan + right_marginal = res.marginal_b # right marginal of the plan + + - **Partial (Fused) Gromov-Wasserstein (GW) problem [29]** (when ``unbalanced='partial'``): + + .. math:: + \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} + + s.t. \ \mathbf{T} \mathbf{1} \leq \mathbf{a} + + \mathbf{T}^T \mathbf{1} \leq \mathbf{b} + + \mathbf{T} \geq 0 + + \mathbf{1}^T\mathbf{T}\mathbf{1} = m + + can be solved with the following code: + + .. code-block:: python + + res = ot.solve_gromov(Ca, Cb, unbalanced_type='partial', unbalanced=0.8) # partial GW with m=0.8 + + + .. _references-solve-gromov: + References + ---------- + + .. [3] Mémoli, F. (2011). Gromov–Wasserstein distances and the metric + approach to object matching. Foundations of computational mathematics, + 11(4), 417-487. + + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016), + Gromov-Wasserstein averaging of kernel and distance matrices + International Conference on Machine Learning (ICML). + + .. [24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N. + (2019). Optimal Transport for structured data with application on graphs + Proceedings of the 36th International Conference on Machine Learning + (ICML). + + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, + Nicolas Courty (2022). Semi-relaxed Gromov-Wasserstein divergence and + applications on graphs. International Conference on Learning + Representations (ICLR), 2022. + + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). Partial Optimal Transport + with Applications on Positive-Unlabeled Learning, Advances in Neural + Information Processing Systems (NeurIPS), 2020. + + """ + + # detect backend + nx = get_backend(Ca, Cb, M, a, b) + + # create uniform weights if not given + if a is None: + a = nx.ones(Ca.shape[0], type_as=Ca) / Ca.shape[0] + if b is None: + b = nx.ones(Cb.shape[1], type_as=Cb) / Cb.shape[1] + + # default values for solutions + potentials = None + value = None + value_linear = None + value_quad = None + plan = None + status = None + + loss_dict = {'l2': 'square_loss', 'kl': 'kl_loss'} + + if loss.lower() not in loss_dict.keys(): + raise (NotImplementedError('Not implemented GW loss="{}"'.format(loss))) + loss_fun = loss_dict[loss.lower()] + + if reg is None or reg == 0: # exact OT + + if unbalanced is None and unbalanced_type.lower() not in ['semirelaxed']: # Exact balanced OT + + if M is None or alpha == 1: # Gromov-Wasserstein problem + + # default values for solver + if max_iter is None: + max_iter = 10000 + if tol is None: + tol = 1e-9 + + value, log = gromov_wasserstein2(Ca, Cb, a, b, loss_fun=loss_fun, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + + value_quad = value + if alpha == 1: # set to 0 for FGW with alpha=1 + value_linear = 0 + plan = log['T'] + potentials = (log['u'], log['v']) + + elif alpha == 0: # Wasserstein problem + + # default values for EMD solver + if max_iter is None: + max_iter = 1000000 + + value_linear, log = emd2(a, b, M, numItermax=max_iter, log=True, return_matrix=True, numThreads=n_threads) + + value = value_linear + potentials = (log['u'], log['v']) + plan = log['G'] + status = log["warning"] if log["warning"] is not None else 'Converged' + value_quad = 0 + + else: # Fused Gromov-Wasserstein problem + + # default values for solver + if max_iter is None: + max_iter = 10000 + if tol is None: + tol = 1e-9 + + value, log = fused_gromov_wasserstein2(M, Ca, Cb, a, b, loss_fun=loss_fun, alpha=alpha, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + + value_linear = log['lin_loss'] + value_quad = log['quad_loss'] + plan = log['T'] + potentials = (log['u'], log['v']) + + elif unbalanced_type.lower() in ['semirelaxed']: # Semi-relaxed OT + + if M is None or alpha == 1: # Semi relaxed Gromov-Wasserstein problem + + # default values for solver + if max_iter is None: + max_iter = 10000 + if tol is None: + tol = 1e-9 + + value, log = semirelaxed_gromov_wasserstein2(Ca, Cb, a, loss_fun=loss_fun, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + + value_quad = value + if alpha == 1: # set to 0 for FGW with alpha=1 + value_linear = 0 + plan = log['T'] + # potentials = (log['u'], log['v']) TODO + + else: # Semi relaxed Fused Gromov-Wasserstein problem + + # default values for solver + if max_iter is None: + max_iter = 10000 + if tol is None: + tol = 1e-9 + + value, log = semirelaxed_fused_gromov_wasserstein2(M, Ca, Cb, a, loss_fun=loss_fun, alpha=alpha, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + + value_linear = log['lin_loss'] + value_quad = log['quad_loss'] + plan = log['T'] + # potentials = (log['u'], log['v']) TODO + + elif unbalanced_type.lower() in ['partial']: # Partial OT + + if M is None: # Partial Gromov-Wasserstein problem + + if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): + raise (ValueError('Partial GW mass given in reg is too large')) + if loss.lower() != 'l2': + raise (NotImplementedError('Partial GW only implemented with L2 loss')) + if symmetric is not None: + raise (NotImplementedError('Partial GW only implemented with symmetric=True')) + + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-7 + + value, log = partial_gromov_wasserstein2(Ca, Cb, a, b, m=unbalanced, log=True, numItermax=max_iter, G0=plan_init, tol=tol, verbose=verbose) + + value_quad = value + plan = log['T'] + # potentials = (log['u'], log['v']) TODO + + else: # partial FGW + + raise (NotImplementedError('Partial FGW not implemented yet')) + + elif unbalanced_type.lower() in ['kl', 'l2']: # unbalanced exact OT + + raise (NotImplementedError('Unbalanced_type="{}"'.format(unbalanced_type))) + + else: + raise (NotImplementedError('Unknown unbalanced_type="{}"'.format(unbalanced_type))) + + else: # regularized OT + + if unbalanced is None and unbalanced_type.lower() not in ['semirelaxed']: # Balanced regularized OT + + if reg_type.lower() in ['entropy'] and (M is None or alpha == 1): # Entropic Gromov-Wasserstein problem + + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + if method is None: + method = 'PGD' + + value_quad, log = entropic_gromov_wasserstein2(Ca, Cb, a, b, epsilon=reg, loss_fun=loss_fun, log=True, symmetric=symmetric, solver=method, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + + plan = log['T'] + value_linear = 0 + value = value_quad + reg * nx.sum(plan * nx.log(plan + 1e-16)) + # potentials = (log['log_u'], log['log_v']) #TODO + + elif reg_type.lower() in ['entropy'] and M is not None and alpha == 0: # Entropic Wasserstein problem + + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + + plan, log = sinkhorn_log(a, b, M, reg=reg, numItermax=max_iter, + stopThr=tol, log=True, + verbose=verbose) + + value_linear = nx.sum(M * plan) + value = value_linear + reg * nx.sum(plan * nx.log(plan + 1e-16)) + potentials = (log['log_u'], log['log_v']) + + elif reg_type.lower() in ['entropy'] and M is not None: # Entropic Fused Gromov-Wasserstein problem + + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + if method is None: + method = 'PGD' + + value_noreg, log = entropic_fused_gromov_wasserstein2(M, Ca, Cb, a, b, loss_fun=loss_fun, alpha=alpha, log=True, symmetric=symmetric, solver=method, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + + value_linear = log['lin_loss'] + value_quad = log['quad_loss'] + plan = log['T'] + # potentials = (log['u'], log['v']) + value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16)) + + else: + raise (NotImplementedError('Not implemented reg_type="{}"'.format(reg_type))) + + elif unbalanced_type.lower() in ['semirelaxed']: # Semi-relaxed OT + + if reg_type.lower() in ['entropy'] and (M is None or alpha == 1): # Entropic Semi-relaxed Gromov-Wasserstein problem + + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + + value_quad, log = entropic_semirelaxed_gromov_wasserstein2(Ca, Cb, a, epsilon=reg, loss_fun=loss_fun, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + + plan = log['T'] + value_linear = 0 + value = value_quad + reg * nx.sum(plan * nx.log(plan + 1e-16)) + + else: # Entropic Semi-relaxed FGW problem + + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + + value_noreg, log = entropic_semirelaxed_fused_gromov_wasserstein2(M, Ca, Cb, a, loss_fun=loss_fun, alpha=alpha, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + + value_linear = log['lin_loss'] + value_quad = log['quad_loss'] + plan = log['T'] + value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16)) + + elif unbalanced_type.lower() in ['partial']: # Partial OT + + if M is None: # Partial Gromov-Wasserstein problem + + if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): + raise (ValueError('Partial GW mass given in reg is too large')) + if loss.lower() != 'l2': + raise (NotImplementedError('Partial GW only implemented with L2 loss')) + if symmetric is not None: + raise (NotImplementedError('Partial GW only implemented with symmetric=True')) + + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-7 + + value_quad, log = entropic_partial_gromov_wasserstein2(Ca, Cb, a, b, reg=reg, m=unbalanced, log=True, numItermax=max_iter, G0=plan_init, tol=tol, verbose=verbose) + + value_quad = value + plan = log['T'] + # potentials = (log['u'], log['v']) TODO + + else: # partial FGW + + raise (NotImplementedError('Partial entropic FGW not implemented yet')) + + else: # unbalanced AND regularized OT + + raise (NotImplementedError('Not implemented reg_type="{}" and unbalanced_type="{}"'.format(reg_type, unbalanced_type))) + + res = OTResult(potentials=potentials, value=value, + value_linear=value_linear, value_quad=value_quad, plan=plan, status=status, backend=nx) + + return res diff --git a/ot/utils.py b/ot/utils.py index 72df4294f..8cbb0db25 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -731,11 +731,12 @@ class UndefinedParameter(Exception): class OTResult: - def __init__(self, potentials=None, value=None, value_linear=None, plan=None, log=None, backend=None, sparse_plan=None, lazy_plan=None, status=None): + def __init__(self, potentials=None, value=None, value_linear=None, value_quad=None, plan=None, log=None, backend=None, sparse_plan=None, lazy_plan=None, status=None): self._potentials = potentials self._value = value self._value_linear = value_linear + self._value_quad = value_quad self._plan = plan self._log = log self._sparse_plan = sparse_plan @@ -828,7 +829,8 @@ def lazy_plan(self): @property def value(self): - """Full transport cost, including possible regularization terms.""" + """Full transport cost, including possible regularization terms and + quadratic term for Gromov Wasserstein solutions.""" if self._value is not None: return self._value else: @@ -842,6 +844,14 @@ def value_linear(self): else: raise NotImplementedError() + @property + def value_quad(self): + """The quadratic part of the transport cost for Gromov-Wasserstein solutions.""" + if self._value_quad is not None: + return self._value_quad + else: + raise NotImplementedError() + # Marginal constraints ------------------------- @property def marginals(self): diff --git a/test/test_gromov.py b/test/test_gromov.py index 846e69f2b..06f843a4a 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -17,7 +17,7 @@ def test_gromov(nx): - n_samples = 50 # nb samples + n_samples = 20 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) @@ -80,7 +80,7 @@ def test_gromov(nx): def test_asymmetric_gromov(nx): - n_samples = 30 # nb samples + n_samples = 20 # nb samples rng = np.random.RandomState(0) C1 = rng.uniform(low=0., high=10, size=(n_samples, n_samples)) idx = np.arange(n_samples) @@ -124,7 +124,7 @@ def test_asymmetric_gromov(nx): def test_gromov_dtype_device(nx): # setup - n_samples = 50 # nb samples + n_samples = 20 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) @@ -160,7 +160,7 @@ def test_gromov_dtype_device(nx): @pytest.mark.skipif(not tf, reason="tf not installed") def test_gromov_device_tf(): nx = ot.backend.TensorflowBackend() - n_samples = 50 # nb samples + n_samples = 20 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) @@ -192,7 +192,7 @@ def test_gromov_device_tf(): def test_gromov2_gradients(): - n_samples = 50 # nb samples + n_samples = 20 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) @@ -257,7 +257,7 @@ def test_gromov2_gradients(): def test_gw_helper_backend(nx): - n_samples = 20 # nb samples + n_samples = 10 # nb samples mu = np.array([0, 0]) cov = np.array([[1, 0], [0, 1]]) @@ -301,7 +301,7 @@ def line_search(cost, G, deltaG, Mi, cost_G): pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)), ]) def test_gw_helper_validation(loss_fun): - n_samples = 20 # nb samples + n_samples = 10 # nb samples mu = np.array([0, 0]) cov = np.array([[1, 0], [0, 1]]) xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0) @@ -548,7 +548,7 @@ def test_entropic_gromov_dtype_device(nx): @pytest.skip_backend("tf", reason="test very slow with tf backend") def test_entropic_fgw(nx): - n_samples = 10 # nb samples + n_samples = 5 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) @@ -613,7 +613,7 @@ def test_entropic_fgw(nx): def test_entropic_proximal_fgw(nx): - n_samples = 10 # nb samples + n_samples = 5 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) @@ -678,7 +678,7 @@ def test_entropic_proximal_fgw(nx): def test_asymmetric_entropic_fgw(nx): - n_samples = 10 # nb samples + n_samples = 5 # nb samples rng = np.random.RandomState(0) C1 = rng.uniform(low=0., high=10, size=(n_samples, n_samples)) idx = np.arange(n_samples) diff --git a/test/test_solvers.py b/test/test_solvers.py index b792aca94..f0f5b638f 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -12,11 +12,17 @@ import ot -lst_reg = [None, 1.0] +lst_reg = [None, 1] lst_reg_type = ['KL', 'entropy', 'L2'] lst_unbalanced = [None, 0.9] lst_unbalanced_type = ['KL', 'L2', 'TV'] +lst_reg_type_gromov = ['entropy'] +lst_gw_losses = ['L2', 'KL'] +lst_unbalanced_type_gromov = ['KL', 'semirelaxed', 'partial'] +lst_unbalanced_gromov = [None, 0.9] +lst_alpha = [0, 0.4, 0.9, 1] + def assert_allclose_sol(sol1, sol2): @@ -107,7 +113,7 @@ def test_solve_grid(nx, reg, reg_type, unbalanced, unbalanced_type): assert_allclose_sol(sol, solb) except NotImplementedError: - pass + pytest.skip("Not implemented") def test_solve_not_implemented(nx): @@ -131,3 +137,121 @@ def test_solve_not_implemented(nx): # pairs of incompatible divergences with pytest.raises(NotImplementedError): ot.solve(M, reg=1.0, reg_type='kl', unbalanced=1.0, unbalanced_type='tv') + + +def test_solve_gromov(nx): + + np.random.seed(0) + + n_samples_s = 3 + n_samples_t = 5 + + Ca = np.random.rand(n_samples_s, n_samples_s) + Ca = (Ca + Ca.T) / 2 + + Cb = np.random.rand(n_samples_t, n_samples_t) + Cb = (Cb + Cb.T) / 2 + + a = ot.utils.unif(n_samples_s) + b = ot.utils.unif(n_samples_t) + + M = np.random.rand(n_samples_s, n_samples_t) + + sol0 = ot.solve_gromov(Ca, Cb) # GW + sol = ot.solve_gromov(Ca, Cb, a=a, b=b) # GW + sol0_fgw = ot.solve_gromov(Ca, Cb, M) # FGW + + # check some attributes + sol.potentials + sol.marginals + + assert_allclose_sol(sol0, sol) + + # solve in backend + ax, bx, Mx, Cax, Cbx = nx.from_numpy(a, b, M, Ca, Cb) + + solx = ot.solve_gromov(Cax, Cbx, a=ax, b=bx) # GW + solx_fgw = ot.solve_gromov(Cax, Cbx, Mx) # FGW + + assert_allclose_sol(sol, solx) + assert_allclose_sol(sol0_fgw, solx_fgw) + + +@pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type,alpha,loss", itertools.product(lst_reg, lst_reg_type_gromov, lst_unbalanced_gromov, lst_unbalanced_type_gromov, lst_alpha, lst_gw_losses)) +def test_solve_gromov_grid(nx, reg, reg_type, unbalanced, unbalanced_type, alpha, loss): + + np.random.seed(0) + + n_samples_s = 3 + n_samples_t = 5 + + Ca = np.random.rand(n_samples_s, n_samples_s) + Ca = (Ca + Ca.T) / 2 + + Cb = np.random.rand(n_samples_t, n_samples_t) + Cb = (Cb + Cb.T) / 2 + + a = ot.utils.unif(n_samples_s) + b = ot.utils.unif(n_samples_t) + + M = np.random.rand(n_samples_s, n_samples_t) + + try: + + sol0 = ot.solve_gromov(Ca, Cb, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type, loss=loss) # GW + sol0_fgw = ot.solve_gromov(Ca, Cb, M, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type, alpha=alpha, loss=loss) # FGW + + # solve in backend + ax, bx, Mx, Cax, Cbx = nx.from_numpy(a, b, M, Ca, Cb) + + solx = ot.solve_gromov(Cax, Cbx, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type, loss=loss) # GW + solx_fgw = ot.solve_gromov(Cax, Cbx, Mx, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type, alpha=alpha, loss=loss) # FGW + + solx.value_quad + + assert_allclose_sol(sol0, solx) + assert_allclose_sol(sol0_fgw, solx_fgw) + + except NotImplementedError: + pytest.skip("Not implemented") + + +def test_solve_gromov_not_implemented(nx): + + np.random.seed(0) + + n_samples_s = 3 + n_samples_t = 5 + + Ca = np.random.rand(n_samples_s, n_samples_s) + Ca = (Ca + Ca.T) / 2 + + Cb = np.random.rand(n_samples_t, n_samples_t) + Cb = (Cb + Cb.T) / 2 + + a = ot.utils.unif(n_samples_s) + b = ot.utils.unif(n_samples_t) + + M = np.random.rand(n_samples_s, n_samples_t) + + Ca, Cb, M, a, b = nx.from_numpy(Ca, Cb, M, a, b) + + # test not implemented and check raise + with pytest.raises(NotImplementedError): + ot.solve_gromov(Ca, Cb, loss='weird loss') + with pytest.raises(NotImplementedError): + ot.solve_gromov(Ca, Cb, unbalanced=1, unbalanced_type='cryptic divergence') + with pytest.raises(NotImplementedError): + ot.solve_gromov(Ca, Cb, reg=1, reg_type='cryptic divergence') + + # detect partial not implemented and error detect in value + with pytest.raises(ValueError): + ot.solve_gromov(Ca, Cb, unbalanced_type='partial', unbalanced=1.5) + with pytest.raises(NotImplementedError): + ot.solve_gromov(Ca, Cb, unbalanced_type='partial', unbalanced=0.5, symmetric=False) + with pytest.raises(NotImplementedError): + ot.solve_gromov(Ca, Cb, M, unbalanced_type='partial', unbalanced=0.5) + with pytest.raises(ValueError): + ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type='partial', unbalanced=1.5) + with pytest.raises(NotImplementedError): + ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type='partial', unbalanced=0.5, symmetric=False) diff --git a/test/test_utils.py b/test/test_utils.py index 787fbe68a..40324518e 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -386,7 +386,8 @@ def test_OTResult(): 'sparse_plan', 'status', 'value', - 'value_linear'] + 'value_linear', + 'value_quad'] for at in lst_attributes: with pytest.raises(NotImplementedError): getattr(res, at)