From 10717598a7e991e02d9fbc30d3a05b852916ea2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Sat, 4 Nov 2023 23:07:34 +0100 Subject: [PATCH] add kl_loss to all semi-relaxed (f)gw solvers (#559) --- RELEASES.md | 1 + ot/gromov/_semirelaxed.py | 80 ++++----- ot/gromov/_utils.py | 27 +++- test/test_gromov.py | 332 +++++++++++++++++++------------------- 4 files changed, 230 insertions(+), 210 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index 7c090bef8..cdc986624 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -14,6 +14,7 @@ + Upgraded unbalanced OT solvers for more flexibility (PR #539) + Add LazyTensor for modeling plans and low rank tensor in large scale OT (PR #544) + Add exact line-search for `gromov_wasserstein` and `fused_gromov_wasserstein` with KL loss (PR #556) ++ Add KL loss to all semi-relaxed (Fused) Gromov-Wasserstein solvers (PR #559) #### Closed issues - Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504) diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py index 0b905c1fa..cbfe64ea8 100644 --- a/ot/gromov/_semirelaxed.py +++ b/ot/gromov/_semirelaxed.py @@ -56,7 +56,6 @@ def semirelaxed_gromov_wasserstein(C1, C2, p=None, loss_fun='square_loss', symme If let to its default value None, uniform distribution is taken. loss_fun : str loss function used for the solver either 'square_loss' or 'kl_loss'. - 'kl_loss' is not implemented yet and will raise an error. symmetric : bool, optional Either C1 and C2 are to be assumed symmetric or not. If let to its default None value, a symmetry test will be conducted. @@ -92,8 +91,6 @@ def semirelaxed_gromov_wasserstein(C1, C2, p=None, loss_fun='square_loss', symme "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" International Conference on Learning Representations (ICLR), 2022. """ - if loss_fun == 'kl_loss': - raise NotImplementedError() arr = [C1, C2] if p is not None: arr.append(list_to_array(p)) @@ -139,7 +136,7 @@ def df(G): return 0.5 * (gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx)) def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): - return solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p, M=0., reg=1., nx=nx, **kwargs) + return solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, ones_p, M=0., reg=1., fC2t=fC2t, nx=nx, **kwargs) if log: res, log = semirelaxed_cg(p, q, 0., 1., f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) @@ -190,7 +187,6 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symm If let to its default value None, uniform distribution is taken. loss_fun : str loss function used for the solver either 'square_loss' or 'kl_loss'. - 'kl_loss' is not implemented yet and will raise an error. symmetric : bool, optional Either C1 and C2 are to be assumed symmetric or not. If let to its default None value, a symmetry test will be conducted. 
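With the guard removed, the semi-relaxed solvers accept `loss_fun='kl_loss'` directly. A minimal usage sketch (illustrative data and variable names, not part of the patch; it assumes a POT build that includes this change):

    import numpy as np
    import ot

    rng = np.random.RandomState(42)
    ns, nt = 20, 10
    C1 = ot.dist(rng.randn(ns, 2))  # source structure matrix
    C2 = ot.dist(rng.randn(nt, 2))  # target structure matrix
    C1 /= C1.max()
    C2 /= C2.max()
    p = ot.unif(ns)  # only the source marginal is prescribed in the semi-relaxed problem

    T, log = ot.gromov.semirelaxed_gromov_wasserstein(
        C1, C2, p, loss_fun='kl_loss', log=True)
    print(log['srgw_dist'])  # divergence value
    print(T.sum(axis=0))     # target marginal, optimized rather than prescribed
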
@@ -243,7 +239,12 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symm
         if loss_fun == 'square_loss':
             gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
             gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
-            srgw = nx.set_gradients(srgw, (C1, C2), (gC1, gC2))
+
+        elif loss_fun == 'kl_loss':
+            gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T))
+            gC2 = - nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)
+
+        srgw = nx.set_gradients(srgw, (C1, C2), (gC1, gC2))

     if log:
         return srgw, log_srgw
@@ -291,7 +292,6 @@ def semirelaxed_fused_gromov_wasserstein(
         If let to its default value None, uniform distribution is taken.
     loss_fun : str
         loss function used for the solver either 'square_loss' or 'kl_loss'.
-        'kl_loss' is not implemented yet and will raise an error.
     symmetric : bool, optional
         Either C1 and C2 are to be assumed symmetric or not.
         If let to its default None value, a symmetry test will be conducted.
@@ -332,9 +332,6 @@ def semirelaxed_fused_gromov_wasserstein(
         "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
         International Conference on Learning Representations (ICLR), 2022.
     """
-    if loss_fun == 'kl_loss':
-        raise NotImplementedError()
-
     arr = [M, C1, C2]
     if p is not None:
         arr.append(list_to_array(p))
@@ -382,7 +379,7 @@ def df(G):

     def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
         return solve_semirelaxed_gromov_linesearch(
-            G, deltaG, cost_G, C1, C2, ones_p, M=(1 - alpha) * M, reg=alpha, nx=nx, **kwargs)
+            G, deltaG, cost_G, hC1, hC2, ones_p, M=(1 - alpha) * M, reg=alpha, fC2t=fC2t, nx=nx, **kwargs)

     if log:
         res, log = semirelaxed_cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
@@ -434,7 +431,6 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo
         If let to its default value None, uniform distribution is taken.
     loss_fun : str, optional
         loss function used for the solver either 'square_loss' or 'kl_loss'.
-        'kl_loss' is not implemented yet and will raise an error.
     symmetric : bool, optional
         Either C1 and C2 are to be assumed symmetric or not.
         If let to its default None value, a symmetry test will be conducted.
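For reference, the KL gradients introduced in this patch (in the hunk above and in the fused hunk that follows) come from differentiating :math:`\sum_{i,j,k,l} L(C^1_{i,j}, C^2_{k,l}) T_{i,k} T_{j,l}` with :math:`L(a, b) = a \log\left(\frac{a}{b}\right) - a + b` and the semi-relaxed marginals :math:`T \mathbf{1} = p`, :math:`T^\top \mathbf{1} = q`; a derivation sketch, not part of the patch:

.. math::

    \frac{\partial}{\partial C^1_{i,j}} &= \log\left(C^1_{i,j}\right) p_i p_j - \left[T \log\left(C^2\right) T^\top\right]_{i,j}

    \frac{\partial}{\partial C^2_{k,l}} &= - \frac{\left[T^\top C^1 T\right]_{k,l}}{C^2_{k,l}} + q_k q_l

These are exactly `gC1` and `gC2` as implemented, with the `1e-15` offsets stabilizing the logarithms and the division.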
@@ -494,15 +490,20 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo
     if loss_fun == 'square_loss':
         gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
         gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
-        if isinstance(alpha, int) or isinstance(alpha, float):
-            srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M),
-                                          (alpha * gC1, alpha * gC2, (1 - alpha) * T))
-        else:
-            lin_term = nx.sum(T * M)
-            srgw_term = (srfgw_dist - (1 - alpha) * lin_term) / alpha
-            srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M, alpha),
-                                          (alpha * gC1, alpha * gC2, (1 - alpha) * T,
-                                           srgw_term - lin_term))
+
+    elif loss_fun == 'kl_loss':
+        gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T))
+        gC2 = - nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)
+
+    if isinstance(alpha, int) or isinstance(alpha, float):
+        srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M),
+                                      (alpha * gC1, alpha * gC2, (1 - alpha) * T))
+    else:
+        lin_term = nx.sum(T * M)
+        srgw_term = (srfgw_dist - (1 - alpha) * lin_term) / alpha
+        srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M, alpha),
+                                      (alpha * gC1, alpha * gC2, (1 - alpha) * T,
+                                       srgw_term - lin_term))

     if log:
         return srfgw_dist, log_fgw
@@ -511,7 +512,7 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo

 def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p,
-                                        M, reg, alpha_min=None, alpha_max=None, nx=None, **kwargs):
+                                        M, reg, fC2t=None, alpha_min=None, alpha_max=None, nx=None, **kwargs):
     """
     Solve the linesearch in the Conditional Gradient iterations for the semi-relaxed Gromov-Wasserstein divergence.

@@ -524,16 +525,22 @@ def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p,
         Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration
     cost_G : float
         Value of the cost at `G`
-    C1 : array-like (ns,ns)
-        Structure matrix in the source domain.
-    C2 : array-like (nt,nt)
-        Structure matrix in the target domain.
+    C1 : array-like (ns,ns)
+        Transformed structure matrix in the source domain.
+        Note that for the 'square_loss' and 'kl_loss', we provide hC1 from ot.gromov.init_matrix_semirelaxed.
+    C2 : array-like (nt,nt)
+        Transformed structure matrix in the target domain.
+        Note that for the 'square_loss' and 'kl_loss', we provide hC2 from ot.gromov.init_matrix_semirelaxed.
     ones_p: array-like (ns,1)
         Array of ones of size ns
     M : array-like (ns,nt)
         Cost matrix between the features.
     reg : float
         Regularization parameter.
+    fC2t : array-like (nt,nt), optional
+        Transformed structure matrix in the target domain.
+        Note that for the 'square_loss' and 'kl_loss', we provide fC2t from ot.gromov.init_matrix_semirelaxed.
+        If fC2t is not provided, it defaults to the fC2t corresponding to the 'square_loss', i.e. `C2.T ** 2`.
    alpha_min : float, optional
        Minimum value for alpha
    alpha_max : float, optional
        Maximum value for alpha
@@ -565,11 +572,14 @@ def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p,
     qG, qdeltaG = nx.sum(G, 0), nx.sum(deltaG, 0)
     dot = nx.dot(nx.dot(C1, deltaG), C2.T)
-    C2t_square = C2.T ** 2
-    dot_qG = nx.dot(nx.outer(ones_p, qG), C2t_square)
-    dot_qdeltaG = nx.dot(nx.outer(ones_p, qdeltaG), C2t_square)
-    a = reg * nx.sum((dot_qdeltaG - 2 * dot) * deltaG)
-    b = nx.sum(M * deltaG) + reg * (nx.sum((dot_qdeltaG - 2 * dot) * G) + nx.sum((dot_qG - 2 * nx.dot(nx.dot(C1, G), C2.T)) * deltaG))
+    if fC2t is None:
+        fC2t = C2.T ** 2
+    dot_qG = nx.dot(nx.outer(ones_p, qG), fC2t)
+    dot_qdeltaG = nx.dot(nx.outer(ones_p, qdeltaG), fC2t)
+
+    a = reg * nx.sum((dot_qdeltaG - dot) * deltaG)
+    b = nx.sum(M * deltaG) + reg * (nx.sum((dot_qdeltaG - dot) * G) + nx.sum((dot_qG - nx.dot(nx.dot(C1, G), C2.T)) * deltaG))
+
     alpha = solve_1d_linesearch_quad(a, b)
     if alpha_min is not None or alpha_max is not None:
         alpha = np.clip(alpha, alpha_min, alpha_max)
@@ -620,7 +630,6 @@ def entropic_semirelaxed_gromov_wasserstein(
         If let to its default value None, uniform distribution is taken.
     loss_fun : str
         loss function used for the solver either 'square_loss' or 'kl_loss'.
-        'kl_loss' is not implemented yet and will raise an error.
     epsilon : float
         Regularization term >0
     symmetric : bool, optional
@@ -655,8 +664,6 @@ def entropic_semirelaxed_gromov_wasserstein(
         "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
         International Conference on Learning Representations (ICLR), 2022.
     """
-    if loss_fun == 'kl_loss':
-        raise NotImplementedError()
     arr = [C1, C2]
     if p is not None:
         arr.append(list_to_array(p))
@@ -777,7 +784,6 @@ def entropic_semirelaxed_gromov_wasserstein2(
         If let to its default value None, uniform distribution is taken.
     loss_fun : str
         loss function used for the solver either 'square_loss' or 'kl_loss'.
-        'kl_loss' is not implemented yet and will raise an error.
     epsilon : float
         Regularization term >0
     symmetric : bool, optional
@@ -869,7 +875,6 @@ def entropic_semirelaxed_fused_gromov_wasserstein(
         If let to its default value None, uniform distribution is taken.
     loss_fun : str
         loss function used for the solver either 'square_loss' or 'kl_loss'.
-        'kl_loss' is not implemented yet and will raise an error.
     epsilon : float
         Regularization term >0
     symmetric : bool, optional
@@ -907,8 +912,6 @@ def entropic_semirelaxed_fused_gromov_wasserstein(
         "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
         International Conference on Learning Representations (ICLR), 2022.
     """
-    if loss_fun == 'kl_loss':
-        raise NotImplementedError()
     arr = [M, C1, C2]
     if p is not None:
         arr.append(list_to_array(p))
@@ -1032,7 +1035,6 @@ def entropic_semirelaxed_fused_gromov_wasserstein2(
         If let to its default value None, uniform distribution is taken.
     loss_fun : str, optional
         loss function used for the solver either 'square_loss' or 'kl_loss'.
-        'kl_loss' is not implemented yet and will raise an error.
     epsilon : float
         Regularization term >0
     symmetric : bool, optional
diff --git a/ot/gromov/_utils.py b/ot/gromov/_utils.py
index d77e44f9e..2c1bda823 100644
--- a/ot/gromov/_utils.py
+++ b/ot/gromov/_utils.py
@@ -399,6 +399,19 @@ def init_matrix_semirelaxed(C1, C2, p, loss_fun='square_loss', nx=None):

             h_2(b) &= 2b

+    The kl-loss function :math:`L(a, b) = a \log\left(\frac{a}{b}\right) - a + b` is read as :
+
+    .. math::
+
+        L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b)
+
+        \mathrm{with} \ f_1(a) &= a \log(a) - a
+
+        f_2(b) &= b
+
+        h_1(a) &= a
+
+        h_2(b) &= \log(b)
+
     Parameters
     ----------
     C1 : array-like, shape (ns, ns)
@@ -451,9 +464,19 @@ def h1(a):
         def h2(b):
             return 2 * b
     elif loss_fun == 'kl_loss':
-        raise NotImplementedError()
+        def f1(a):
+            return a * nx.log(a + 1e-15) - a
+
+        def f2(b):
+            return b
+
+        def h1(a):
+            return a
+
+        def h2(b):
+            return nx.log(b + 1e-15)
     else:
-        raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Only 'square_loss' is supported.")
+        raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")

     constC = nx.dot(nx.dot(f1(C1), nx.reshape(p, (-1, 1))),
                     nx.ones((1, C2.shape[0]), type_as=p))
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 8870a5023..a71433bb5 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -1941,32 +1941,33 @@ def test_semirelaxed_gromov(nx):
     # asymmetric
     C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0)
-    G, log = ot.gromov.semirelaxed_gromov_wasserstein(
-        C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=G0)
-    Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein(
-        C1b, C2b, None, loss_fun='square_loss', symmetric=False, log=True,
-        G0=None, alpha_min=0., alpha_max=1.)
+    for loss_fun in ['square_loss', 'kl_loss']:
+        G, log = ot.gromov.semirelaxed_gromov_wasserstein(
+            C1, C2, p, loss_fun=loss_fun, symmetric=None, log=True, G0=G0)
+        Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein(
+            C1b, C2b, None, loss_fun=loss_fun, symmetric=False, log=True,
+            G0=None, alpha_min=0., alpha_max=1.)

-    # check constraints
-    np.testing.assert_allclose(G, Gb, atol=1e-06)
-    np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04)
-    np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01)
-    np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01)
+        # check constraints
+        np.testing.assert_allclose(G, Gb, atol=1e-06)
+        np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04)
+        np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01)
+        np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01)

-    srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2(
-        C1, C2, None, loss_fun='square_loss', symmetric=False, log=True, G0=G0)
-    srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2(
-        C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None)
+        srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2(
+            C1, C2, None, loss_fun=loss_fun, symmetric=False, log=True, G0=G0)
+        srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2(
+            C1b, C2b, pb, loss_fun=loss_fun, symmetric=None, log=True, G0=None)

-    G = log2['T']
-    Gb = nx.to_numpy(logb2['T'])
-    # check constraints
-    np.testing.assert_allclose(G, Gb, atol=1e-06)
-    np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04)  # cf convergence gromov
-    np.testing.assert_allclose(list_n / ns, Gb.sum(0), atol=1e-04)  # cf convergence gromov
+        G = log2['T']
+        Gb = nx.to_numpy(logb2['T'])
+        # check constraints
+        np.testing.assert_allclose(G, Gb, atol=1e-06)
+        np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04)  # cf convergence gromov
+        np.testing.assert_allclose(list_n / ns, Gb.sum(0), atol=1e-04)  # cf convergence gromov

-    np.testing.assert_allclose(log2['srgw_dist'], logb['srgw_dist'], atol=1e-07)
-    np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07)
+        np.testing.assert_allclose(log2['srgw_dist'], logb['srgw_dist'], atol=1e-07)
+        
np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) # symmetric C1 = 0.5 * (C1 + C1.T) @@ -2025,19 +2026,20 @@ def test_semirelaxed_gromov2_gradients(): if torch.cuda.is_available(): devices.append(torch.device("cuda")) for device in devices: - # semirelaxed solvers do not support gradients over masses yet. - p1 = torch.tensor(p, requires_grad=False, device=device) - C11 = torch.tensor(C1, requires_grad=True, device=device) - C12 = torch.tensor(C2, requires_grad=True, device=device) + for loss_fun in ['square_loss', 'kl_loss']: + # semirelaxed solvers do not support gradients over masses yet. + p1 = torch.tensor(p, requires_grad=False, device=device) + C11 = torch.tensor(C1, requires_grad=True, device=device) + C12 = torch.tensor(C2, requires_grad=True, device=device) - val = ot.gromov.semirelaxed_gromov_wasserstein2(C11, C12, p1) + val = ot.gromov.semirelaxed_gromov_wasserstein2(C11, C12, p1, loss_fun=loss_fun) - val.backward() + val.backward() - assert val.device == p1.device - assert p1.grad is None - assert C11.shape == C11.grad.shape - assert C12.shape == C12.grad.shape + assert val.device == p1.device + assert p1.grad is None + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape def test_srgw_helper_backend(nx): @@ -2057,35 +2059,35 @@ def test_srgw_helper_backend(nx): C1 /= C1.max() C2 /= C2.max() - C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q) - Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, 'square_loss', armijo=False, symmetric=True, G0=None, log=True) - - # calls with nx=None - constCb, hC1b, hC2b, fC2tb = ot.gromov.init_matrix_semirelaxed(C1b, C2b, pb, loss_fun='square_loss') - ones_pb = nx.ones(pb.shape[0], type_as=pb) - - def f(G): - qG = nx.sum(G, 0) - marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb)) - return ot.gromov.gwloss(constCb + marginal_product, hC1b, hC2b, G, nx=None) - - def df(G): - qG = nx.sum(G, 0) - marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb)) - return ot.gromov.gwggrad(constCb + marginal_product, hC1b, hC2b, G, nx=None) - - def line_search(cost, G, deltaG, Mi, cost_G): - return ot.gromov.solve_semirelaxed_gromov_linesearch( - G, deltaG, cost_G, C1b, C2b, ones_pb, 0., 1., nx=None) - # feed the precomputed local optimum Gb to semirelaxed_cg - res, log = ot.optim.semirelaxed_cg(pb, qb, 0., 1., f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) - # check constraints - np.testing.assert_allclose(res, Gb, atol=1e-06) + for loss_fun in ['square_loss', 'kl_loss']: + C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q) + Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, loss_fun, armijo=False, symmetric=True, G0=None, log=True) + + # calls with nx=None + constCb, hC1b, hC2b, fC2tb = ot.gromov.init_matrix_semirelaxed(C1b, C2b, pb, loss_fun) + ones_pb = nx.ones(pb.shape[0], type_as=pb) + + def f(G): + qG = nx.sum(G, 0) + marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb)) + return ot.gromov.gwloss(constCb + marginal_product, hC1b, hC2b, G, nx=None) + + def df(G): + qG = nx.sum(G, 0) + marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb)) + return ot.gromov.gwggrad(constCb + marginal_product, hC1b, hC2b, G, nx=None) + + def line_search(cost, G, deltaG, Mi, cost_G): + return ot.gromov.solve_semirelaxed_gromov_linesearch( + G, deltaG, cost_G, hC1b, hC2b, ones_pb, 0., 1., fC2t=fC2tb, nx=None) + # feed the precomputed local optimum Gb to semirelaxed_cg + res, log = ot.optim.semirelaxed_cg(pb, qb, 0., 1., f, df, Gb, line_search, log=True, 
numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) + # check constraints + np.testing.assert_allclose(res, Gb, atol=1e-06) @pytest.mark.parametrize('loss_fun', [ - 'square_loss', - pytest.param('kl_loss', marks=pytest.mark.xfail(raises=NotImplementedError)), + 'square_loss', 'kl_loss', pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)), ]) def test_gw_semirelaxed_helper_validation(loss_fun): @@ -2149,32 +2151,33 @@ def test_semirelaxed_fgw(nx): np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) # symmetric - C1 = 0.5 * (C1 + C1.T) - Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) + for loss_fun in ['square_loss', 'kl_loss']: + C1 = 0.5 * (C1 + C1.T) + Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) - G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) - Gb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=True, log=False, G0=G0b) + G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun=loss_fun, alpha=0.5, symmetric=None, log=True, G0=None) + Gb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun=loss_fun, alpha=0.5, symmetric=True, log=False, G0=G0b) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov - srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=True, log=True, G0=G0) - srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) + srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun=loss_fun, alpha=0.5, symmetric=True, log=True, G0=G0) + srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun=loss_fun, alpha=0.5, symmetric=None, log=True, G0=None) - srgw_ = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=True, log=False, G0=G0) + srgw_ = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun=loss_fun, alpha=0.5, symmetric=True, log=False, G0=G0) - G = log2['T'] - Gb = nx.to_numpy(logb2['T']) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov + G = log2['T'] + Gb = nx.to_numpy(logb2['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(log2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) - np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) - np.testing.assert_allclose(srgw, srgw_, atol=1e-07) + 
np.testing.assert_allclose(log2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) + np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) + np.testing.assert_allclose(srgw, srgw_, atol=1e-07) def test_semirelaxed_fgw2_gradients(): @@ -2203,37 +2206,38 @@ def test_semirelaxed_fgw2_gradients(): devices.append(torch.device("cuda")) for device in devices: # semirelaxed solvers do not support gradients over masses yet. - p1 = torch.tensor(p, requires_grad=False, device=device) - C11 = torch.tensor(C1, requires_grad=True, device=device) - C12 = torch.tensor(C2, requires_grad=True, device=device) - M1 = torch.tensor(M, requires_grad=True, device=device) + for loss_fun in ['square_loss', 'kl_loss']: + p1 = torch.tensor(p, requires_grad=False, device=device) + C11 = torch.tensor(C1, requires_grad=True, device=device) + C12 = torch.tensor(C2, requires_grad=True, device=device) + M1 = torch.tensor(M, requires_grad=True, device=device) - val = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M1, C11, C12, p1) + val = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M1, C11, C12, p1, loss_fun=loss_fun) - val.backward() + val.backward() - assert val.device == p1.device - assert p1.grad is None - assert C11.shape == C11.grad.shape - assert C12.shape == C12.grad.shape - assert M1.shape == M1.grad.shape + assert val.device == p1.device + assert p1.grad is None + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + assert M1.shape == M1.grad.shape - # full gradients with alpha - p1 = torch.tensor(p, requires_grad=False, device=device) - C11 = torch.tensor(C1, requires_grad=True, device=device) - C12 = torch.tensor(C2, requires_grad=True, device=device) - M1 = torch.tensor(M, requires_grad=True, device=device) - alpha = torch.tensor(0.5, requires_grad=True, device=device) + # full gradients with alpha + p1 = torch.tensor(p, requires_grad=False, device=device) + C11 = torch.tensor(C1, requires_grad=True, device=device) + C12 = torch.tensor(C2, requires_grad=True, device=device) + M1 = torch.tensor(M, requires_grad=True, device=device) + alpha = torch.tensor(0.5, requires_grad=True, device=device) - val = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M1, C11, C12, p1, alpha=alpha) + val = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M1, C11, C12, p1, loss_fun=loss_fun, alpha=alpha) - val.backward() + val.backward() - assert val.device == p1.device - assert p1.grad is None - assert C11.shape == C11.grad.shape - assert C12.shape == C12.grad.shape - assert alpha.shape == alpha.grad.shape + assert val.device == p1.device + assert p1.grad is None + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + assert alpha.shape == alpha.grad.shape def test_srfgw_helper_backend(nx): @@ -2309,27 +2313,28 @@ def test_entropic_semirelaxed_gromov(nx): # asymmetric C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) epsilon = 0.1 - G, log = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=None, log=True, G0=G0) - Gb, logb = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1b, C2b, None, loss_fun='square_loss', epsilon=epsilon, symmetric=False, log=True, G0=None) + for loss_fun in ['square_loss', 'kl_loss']: + G, log = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun=loss_fun, epsilon=epsilon, symmetric=None, log=True, G0=G0) + Gb, logb = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1b, C2b, None, loss_fun=loss_fun, epsilon=epsilon, symmetric=False, log=True, G0=None) - 
# check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) - np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01) - np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) + np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01) + np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01) - srgw, log2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1, C2, None, loss_fun='square_loss', epsilon=epsilon, symmetric=False, log=True, G0=G0) - srgwb, logb2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun='square_loss', epsilon=epsilon, symmetric=None, log=True, G0=None) + srgw, log2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1, C2, None, loss_fun=loss_fun, epsilon=epsilon, symmetric=False, log=True, G0=G0) + srgwb, logb2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun=loss_fun, epsilon=epsilon, symmetric=None, log=True, G0=None) - G = log2['T'] - Gb = nx.to_numpy(logb2['T']) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(list_n / ns, Gb.sum(0), atol=1e-04) # cf convergence gromov + G = log2['T'] + Gb = nx.to_numpy(logb2['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(list_n / ns, Gb.sum(0), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(log2['srgw_dist'], logb['srgw_dist'], atol=1e-07) - np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) + np.testing.assert_allclose(log2['srgw_dist'], logb['srgw_dist'], atol=1e-07) + np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) # symmetric C1 = 0.5 * (C1 + C1.T) @@ -2382,19 +2387,20 @@ def test_entropic_semirelaxed_gromov_dtype_device(nx): C2 /= C2.max() for tp in nx.__type_list__: - print(nx.dtype_device(tp)) - C1b, C2b, pb = nx.from_numpy(C1, C2, p, type_as=tp) + print(nx.dtype_device(tp)) + for loss_fun in ['square_loss', 'kl_loss']: + C1b, C2b, pb = nx.from_numpy(C1, C2, p, type_as=tp) - Gb = ot.gromov.entropic_semirelaxed_gromov_wasserstein( - C1b, C2b, pb, 'square_loss', epsilon=0.1, verbose=True - ) - gw_valb = ot.gromov.entropic_semirelaxed_gromov_wasserstein2( - C1b, C2b, pb, 'square_loss', epsilon=0.1, verbose=True - ) + Gb = ot.gromov.entropic_semirelaxed_gromov_wasserstein( + C1b, C2b, pb, loss_fun, epsilon=0.1, verbose=True + ) + gw_valb = ot.gromov.entropic_semirelaxed_gromov_wasserstein2( + C1b, C2b, pb, loss_fun, epsilon=0.1, verbose=True + ) - nx.assert_same_dtype_device(C1b, Gb) - nx.assert_same_dtype_device(C1b, gw_valb) + nx.assert_same_dtype_device(C1b, Gb) + nx.assert_same_dtype_device(C1b, gw_valb) def test_entropic_semirelaxed_fgw(nx): @@ -2450,29 +2456,30 @@ def test_entropic_semirelaxed_fgw(nx): C1 = 0.5 * (C1 + C1.T) Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) - G, log = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) - Gb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', epsilon=0.1, 
alpha=0.5, symmetric=True, log=False, G0=G0b) + for loss_fun in ['square_loss', 'kl_loss']: + G, log = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) + Gb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=True, log=False, G0=G0b) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov - srgw, log2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=True, log=True, G0=G0) - srgwb, logb2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) + srgw, log2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=True, log=True, G0=G0) + srgwb, logb2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) - srgw_ = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=True, log=False, G0=G0) + srgw_ = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=True, log=False, G0=G0) - G = log2['T'] - Gb = nx.to_numpy(logb2['T']) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov + G = log2['T'] + Gb = nx.to_numpy(logb2['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(log2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) - np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) - np.testing.assert_allclose(srgw, srgw_, atol=1e-07) + np.testing.assert_allclose(log2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) + np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) + np.testing.assert_allclose(srgw, srgw_, atol=1e-07) @pytest.skip_backend("tf", reason="test very slow with tf backend") @@ -2505,15 +2512,16 @@ def test_entropic_semirelaxed_fgw_dtype_device(nx): Mb, C1b, C2b, pb = nx.from_numpy(M, C1, C2, p, type_as=tp) - Gb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein( - Mb, C1b, C2b, pb, 'square_loss', epsilon=0.1, verbose=True - ) - fgw_valb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2( - Mb, C1b, C2b, pb, 'square_loss', epsilon=0.1, verbose=True - ) + for loss_fun in ['square_loss', 'kl_loss']: + Gb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein( + Mb, C1b, C2b, pb, loss_fun, epsilon=0.1, 
verbose=True + ) + fgw_valb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2( + Mb, C1b, C2b, pb, loss_fun, epsilon=0.1, verbose=True + ) - nx.assert_same_dtype_device(C1b, Gb) - nx.assert_same_dtype_device(C1b, fgw_valb) + nx.assert_same_dtype_device(C1b, Gb) + nx.assert_same_dtype_device(C1b, fgw_valb) def test_not_implemented_solver(): @@ -2546,17 +2554,3 @@ def test_not_implemented_solver(): with pytest.raises(ValueError): ot.gromov.entropic_fused_gromov_wasserstein( M, C1, C2, p, q, 'square_loss', epsilon=1e-1, solver=solver) - - # exact and entropic srgw and srfgw loss functions - loss_fun = 'kl_loss' - with pytest.raises(NotImplementedError): - ot.gromov.semirelaxed_gromov_wasserstein( - C1, C2, p, loss_fun, armijo=False) - with pytest.raises(NotImplementedError): - ot.gromov.entropic_semirelaxed_gromov_wasserstein( - C1, C2, p, loss_fun, epsilon=0.1) - with pytest.raises(NotImplementedError): - ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun) - with pytest.raises(NotImplementedError): - ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein( - M, C1, C2, p, loss_fun, epsilon=0.1)
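As a complement to the tests above, the half-loss decomposition that `init_matrix_semirelaxed` now uses for 'kl_loss' can be checked numerically in isolation; a standalone sketch (not part of the patch) verifying that :math:`f_1(a) + f_2(b) - h_1(a) h_2(b)` recovers :math:`a \log(a/b) - a + b`:

    import numpy as np

    rng = np.random.RandomState(0)
    a = rng.rand(5) + 0.1  # keep entries away from zero so the logs are well behaved
    b = rng.rand(5) + 0.1

    # the four half-losses added to init_matrix_semirelaxed for 'kl_loss'
    f1 = a * np.log(a + 1e-15) - a
    f2 = b
    h1 = a
    h2 = np.log(b + 1e-15)

    decomposed = f1 + f2 - h1 * h2
    direct = a * np.log(a / b) - a + b
    np.testing.assert_allclose(decomposed, direct, atol=1e-12)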