Skip to content

Commit

Permalink
add new API for gromov
Browse files Browse the repository at this point in the history
  • Loading branch information
rflamary committed Oct 18, 2023
1 parent ffdd1cf commit 223396e
Show file tree
Hide file tree
Showing 7 changed files with 376 additions and 5 deletions.
2 changes: 1 addition & 1 deletion ot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
from .weak import weak_optimal_transport
from .factored import factored_optimal_transport
from .solvers import solve
from .solvers import solve, solve_gromov

# utils functions
from .utils import dist, unif, tic, toc, toq
Expand Down
1 change: 1 addition & 0 deletions ot/da.py
Original file line number Diff line number Diff line change
Expand Up @@ -2274,6 +2274,7 @@ class NearestBrenierPotential(BaseTransport):
ot.mapping.nearest_brenier_potential_fit : Fitting the SSNB on source and target data
ot.mapping.nearest_brenier_potential_predict_bounds : Predicting SSNB images on new source data
"""

def __init__(self, strongly_convex_constant=0.6, gradient_lipschitz_constant=1.4, log=False, its=100, seed=None):
self.strongly_convex_constant = strongly_convex_constant
self.gradient_lipschitz_constant = gradient_lipschitz_constant
Expand Down
10 changes: 10 additions & 0 deletions ot/gromov/_bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ def entropic_gromov_wasserstein2(
learning for graph matching and node embedding. In International
Conference on Machine Learning (ICML), 2019.
"""

T, logv = entropic_gromov_wasserstein(
C1, C2, p, q, loss_fun, epsilon, symmetric, G0, max_iter,
tol, solver, warmstart, verbose, log=True, **kwargs)
Expand Down Expand Up @@ -815,12 +816,21 @@ def entropic_fused_gromov_wasserstein2(
(ICML). 2019.
"""

nx = get_backend(M, C1, C2)

T, logv = entropic_fused_gromov_wasserstein(
M, C1, C2, p, q, loss_fun, epsilon, symmetric, alpha, G0, max_iter,
tol, solver, warmstart, verbose, log=True, **kwargs)

logv['T'] = T

lin_term = nx.sum(T * M)
gw_term = (logv['gw_dist'] - (1 - alpha) * lin_term) / alpha

log_fgw['quad_loss'] = gw_term * alpha
log_fgw['lin_loss'] = lin_term * (1 - alpha)

if log:
return logv['fgw_dist'], logv
else:
Expand Down
10 changes: 8 additions & 2 deletions ot/gromov/_gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,13 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss',
fgw_dist = log_fgw['fgw_dist']
log_fgw['T'] = T

# compute separate terms for gradients and log
lin_term = nx.sum(T * M)
gw_term = (fgw_dist - (1 - alpha) * lin_term) / alpha

log_fgw['quad_loss'] = gw_term * alpha
log_fgw['lin_loss'] = lin_term * (1 - alpha)

if loss_fun == 'square_loss':
gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
Expand All @@ -591,8 +598,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss',
log_fgw['v'] - nx.mean(log_fgw['v']),
alpha * gC1, alpha * gC2, (1 - alpha) * T))
else:
lin_term = nx.sum(T * M)
gw_term = (fgw_dist - (1 - alpha) * lin_term) / alpha

fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M, alpha),
(log_fgw['u'] - nx.mean(log_fgw['u']),
log_fgw['v'] - nx.mean(log_fgw['v']),
Expand Down
264 changes: 264 additions & 0 deletions ot/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from .bregman import sinkhorn_log
from .partial import partial_wasserstein_lagrange
from .smooth import smooth_ot_dual
from .gromov import gromov_wasserstein2, fused_gromov_wasserstein2, entropic_gromov_wasserstein2, entropic_fused_gromov_wasserstein2

#, entropic_gromov_wasserstein2, entropic_fused_gromov_wasserstein2


def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
Expand Down Expand Up @@ -345,3 +348,264 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
value_linear=value_linear, plan=plan, status=status, backend=nx)

return res


def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, alpha=0.5, reg=None,
reg_type="entropy", unbalanced=None, unbalanced_type='KL',
n_threads=1, method=None, max_iter=None, plan_init=None, tol=None,
verbose=False):
r""" Solve the discrete (Fused) Gromov-Wasserstein and return :any:`OTResult` object
The function solves the following optimization problem:
.. math::
\min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
\alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})
The regularization is selected with `reg` (:math:`\lambda_r`) and
`reg_type`. By default ``reg=None`` and there is no regularization. The
unbalanced marginal penalization can be selected with `unbalanced`
(:math:`\lambda_u`) and `unbalanced_type`. By default ``unbalanced=None``
and the function solves the exact optimal transport problem (respecting the
marginals).
Parameters
----------
Ca : array_like, shape (dim_a, dim_a)
Cost matrix in the source domain
Cb : array_like, shape (dim_b, dim_b)
Cost matrix in the target domain
M : array_like, shape (dim_a, dim_b), optional
Linear cost matrix for Fused Gromov-Wasserstein (default is None).
a : array-like, shape (dim_a,), optional
Samples weights in the source domain (default is uniform)
b : array-like, shape (dim_b,), optional
Samples weights in the source domain (default is uniform)
loss : str, optional
Type of loss function, either ``"L2"`` or ``"KL"``, by default ``"L2"``
symmetric : bool, optional
Use symmetric version of the Gromov-Wasserstein problem, by default None
tests wether the matrices are symmetric or True/False to avoid the test.
reg : float, optional
Regularization weight :math:`\lambda_r`, by default None (no reg., exact
OT)
reg_type : str, optional
Type of regularization :math:`R`, by default "entropic" (only used when
``reg!=None``)
alpha : float, optional
Weight the quadratic term (alpha*Gromov) and the linear term
((1-alpha)*Wass) in the Fused Gromov-Wasserstein problem. Not used for
Gromov problem (when M is not provided). By default ``alpha=None` corresponds to to
``alpha=1`` for Gromov problem (``M==None``) and ``alpha=0.5`` for Fused
Gromov-Wasserstein problem (``M!=None``)
unbalanced : float, optional
Unbalanced penalization weight :math:`\lambda_u`, by default None
(balanced OT), Not implemented yet
unbalanced_type : str, optional
Type of unbalanced penalization unction :math:`U` either "KL", "L2",
"TV", by default "KL" , Not implemented yet
n_threads : int, optional
Number of OMP threads for exact OT solver, by default 1
max_iter : int, optional
Maximum number of iteration, by default None (default values in each
solvers)
plan_init : array_like, shape (dim_a, dim_b), optional
Initialization of the OT plan for iterative methods, by default None
tol : float, optional
Tolerance for solution precision, by default None (default values in
each solvers)
verbose : bool, optional
Print information in the solver, by default False
Returns
-------
res : OTResult()
Result of the optimization problem. The information can be obtained as follows:
- res.plan : OT plan :math:`\mathbf{T}`
- res.potentials : OT dual potentials
- res.value : Optimal value of the optimization problem
- res.value_linear : Linear OT loss with the optimal OT plan
- res.value_quad : Quadratic (GW) part of the OT loss with the optimal OT plan
See :any:`OTResult` for more information.
Notes
-----
The following methods are available for solving the Gromov-Wasserstein
problem:
"""

# detect backend
nx = get_backend(Ca, Cb, M, a, b)

# create uniform weights if not given
if a is None:
a = nx.ones(Ca.shape[0], type_as=Ca) / Ca.shape[0]
if b is None:
b = nx.ones(Cb.shape[1], type_as=Cb) / Cb.shape[1]

# default values for solutions
potentials = None
value = None
value_linear = None
value_quad = None
plan = None
status = None

loss_dict = {'l2': 'square_loss', 'kl': 'kl_loss'}

if reg is None or reg == 0: # exact OT

if unbalanced is None: # Exact balanced OT

if M is None or alpha == 1: # Gromov-Wasserstein problem

# default values for solver
if max_iter is None:
max_iter = 10000
if tol is None:
tol = 1e-9

value, log = gromov_wasserstein2(Ca, Cb, a, b, loss_fun=loss_dict[loss.lower()], log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose)

value_quad = value
if alpha == 1: # set to 0 for FGW with alpha=1
value_linear = 0
plan = log['T']
potentials = (log['u'], log['v'])

elif alpha == 0: # Wasserstein problem

# default values for EMD solver
if max_iter is None:
max_iter = 1000000

value_linear, log = emd2(a, b, M, numItermax=max_iter, log=True, return_matrix=True, numThreads=n_threads)

value = value_linear
potentials = (log['u'], log['v'])
plan = log['G']
status = log["warning"] if log["warning"] is not None else 'Converged'
value_quad = 0

else: # Fused Gromov-Wasserstein problem

# default values for solver
if max_iter is None:
max_iter = 10000
if tol is None:
tol = 1e-9

value, log = fused_gromov_wasserstein2(M, Ca, Cb, a, b, loss_fun=loss_dict[loss.lower()], alpha=alpha, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose)

value_linear = log['lin_loss']
value_quad = log['quad_loss']
plan = log['T']
potentials = (log['u'], log['v'])

elif unbalanced_type.lower() in ['kl', 'l2']: # unbalanced exact OT

raise (NotImplementedError('Unbalanced_type="{}"'.format(unbalanced_type)))

else:
raise (NotImplementedError('Unknown unbalanced_type="{}"'.format(unbalanced_type)))

else: # regularized OT

if unbalanced is None: # Balanced regularized OT

if reg_type.lower() in ['entropy'] and (M is None or alpha == 1): # Entropic Gromov-Wasserstein problem

# default values for solver
if max_iter is None:
max_iter = 1000
if tol is None:
tol = 1e-9
if method is None:
method = 'PGD'

value_quad, log = entropic_gromov_wasserstein2(Ca, Cb, a, b, epsilon=reg, loss_fun=loss_dict[loss.lower()], log=True, symmetric=symmetric, solver=method, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose)

plan = log['T']
value_linear = 0
value = value_quad + reg * nx.sum(plan * nx.log(plan + 1e-16))
# potentials = (log['log_u'], log['log_v']) #TODO

elif reg_type.lower() in ['entropy'] and M is not None and alpha == 0: # Entropic Wasserstein problem

# default values for solver
if max_iter is None:
max_iter = 1000
if tol is None:
tol = 1e-9

plan, log = sinkhorn_log(a, b, M, reg=reg, numItermax=max_iter,
stopThr=tol, log=True,
verbose=verbose)

value_linear = nx.sum(M * plan)
value = value_linear + reg * nx.sum(plan * nx.log(plan + 1e-16))
potentials = (log['log_u'], log['log_v'])

elif reg_type.lower() in ['entropy'] and M is not None: # Entropic Fused Gromov-Wasserstein problem

# default values for solver
if max_iter is None:
max_iter = 1000
if tol is None:
tol = 1e-9
if method is None:
method = 'PGD'

value_noreg, log = entropic_fused_gromov_wasserstein2(M, Ca, Cb, a, b, loss_fun=loss_dict[loss.lower()], alpha=alpha, log=True, symmetric=symmetric, solver=method, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose)

value_linear = log['lin_loss']
value_quad = log['quad_loss']
plan = log['T']
# potentials = (log['u'], log['v'])
value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16))

else:
raise (NotImplementedError('Not implemented reg_type="{}"'.format(reg_type)))

else: # unbalanced AND regularized OT

raise (NotImplementedError('Not implemented reg_type="{}" and unbalanced_type="{}"'.format(reg_type, unbalanced_type)))

# if reg_type.lower() in ['kl'] and unbalanced_type.lower() == 'kl':

# if max_iter is None:
# max_iter = 1000
# if tol is None:
# tol = 1e-9

# plan, log = sinkhorn_knopp_unbalanced(a, b, M, reg=reg, reg_m=unbalanced, numItermax=max_iter, stopThr=tol, verbose=verbose, log=True)

# value_linear = nx.sum(M * plan)

# value = value_linear + reg * nx.kl_div(plan, a[:, None] * b[None, :]) + unbalanced * (nx.kl_div(nx.sum(plan, 1), a) + nx.kl_div(nx.sum(plan, 0), b))

# potentials = (log['logu'], log['logv'])

# elif reg_type.lower() in ['kl', 'l2', 'entropy'] and unbalanced_type.lower() in ['kl', 'l2']:

# if max_iter is None:
# max_iter = 1000
# if tol is None:
# tol = 1e-12

# plan, log = lbfgsb_unbalanced(a, b, M, reg=reg, reg_m=unbalanced, reg_div=reg_type.lower(), regm_div=unbalanced_type.lower(), numItermax=max_iter, stopThr=tol, verbose=verbose, log=True)

# value_linear = nx.sum(M * plan)

# value = log['loss']

# else:
# raise (NotImplementedError('Not implemented reg_type="{}" and unbalanced_type="{}"'.format(reg_type, unbalanced_type)))

res = OTResult(potentials=potentials, value=value,
value_linear=value_linear, value_quad=value_quad, plan=plan, status=status, backend=nx)

return res
14 changes: 12 additions & 2 deletions ot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,11 +731,12 @@ class UndefinedParameter(Exception):


class OTResult:
def __init__(self, potentials=None, value=None, value_linear=None, plan=None, log=None, backend=None, sparse_plan=None, lazy_plan=None, status=None):
def __init__(self, potentials=None, value=None, value_linear=None, value_quad=None, plan=None, log=None, backend=None, sparse_plan=None, lazy_plan=None, status=None):

self._potentials = potentials
self._value = value
self._value_linear = value_linear
self._value_quad = value_quad
self._plan = plan
self._log = log
self._sparse_plan = sparse_plan
Expand Down Expand Up @@ -828,7 +829,8 @@ def lazy_plan(self):

@property
def value(self):
"""Full transport cost, including possible regularization terms."""
"""Full transport cost, including possible regularization terms and
quadratic term for Gromov Wasserstein solutions."""
if self._value is not None:
return self._value
else:
Expand All @@ -842,6 +844,14 @@ def value_linear(self):
else:
raise NotImplementedError()

@property
def value_quad(self):
"""The quadratic part of the transport cost for Gromov-Wasserstein solutions."""
if self._value_quad is not None:
return self._value_quad
else:
raise NotImplementedError()

# Marginal constraints -------------------------
@property
def marginals(self):
Expand Down
Loading

0 comments on commit 223396e

Please sign in to comment.