Skip to content

Commit

Permalink
clean ot.solve_sample and remve lazy test cause not ilplemented yet
Browse files Browse the repository at this point in the history
  • Loading branch information
rflamary committed Nov 7, 2023
1 parent c7e899f commit 9403851
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 102 deletions.
120 changes: 58 additions & 62 deletions ot/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,6 @@
from .partial import partial_gromov_wasserstein2, entropic_partial_gromov_wasserstein2




#, entropic_gromov_wasserstein2, entropic_fused_gromov_wasserstein2


def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
unbalanced_type='KL', n_threads=1, max_iter=None, plan_init=None,
potentials_init=None, tol=None, verbose=False):
Expand Down Expand Up @@ -853,14 +848,9 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None,
return res




##### new ot.solve_sample function

def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL", unbalanced=None,
unbalanced_type='KL', is_Lazy=False, batch_size=None, n_threads=1, max_iter=None, plan_init=None,
potentials_init=None, tol=None, verbose=False):

def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL", unbalanced=None,
unbalanced_type='KL', lazy=False, batch_size=None, method=None, n_threads=1, max_iter=None, plan_init=None,
potentials_init=None, tol=None, verbose=False):
r"""Solve the discrete optimal transport problem using the samples in the source and target domains.
The function solves the following general optimal transport problem
Expand All @@ -870,6 +860,10 @@ def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
\lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) +
\lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})
where the cost matrix :math:`\mathbf{M}` is computed from the samples in the
source and target domains wuch that :math:`M_{i,j} = d(x_i,y_j)` where
:math:`d` is a metric (by default the squared Euclidean distance).
The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By
default ``reg=None`` and there is no regularization. The unbalanced marginal
penalization can be selected with `unbalanced` (:math:`\lambda_u`) and
Expand All @@ -881,7 +875,7 @@ def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
X_s : array-like, shape (n_samples_a, dim)
samples in the source domain
X_t : array-like, shape (n_samples_b, dim)
samples in the target domain
samples in the target domain
a : array-like, shape (dim_a,), optional
Samples weights in the source domain (default is uniform)
b : array-like, shape (dim_b,), optional
Expand All @@ -896,8 +890,16 @@ def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
(balanced OT)
unbalanced_type : str, optional
Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL"
is_Lazy : bool, optional
Return :any:`OTResultlazy` object to reduce memory cost when True, by default False
lazy : bool, optional
Return :any:`OTResultlazy` object to reduce memory cost when True, by
default False
batch_size : int, optional
Batch size for lazy solver, by default None (default values in each
solvers)
method : str, optional
Method for solving the problem, this can be used to select the solver
for unalanced problems (see :any:`ot.solve`), or to select a specific
lazy large scale solver.
n_threads : int, optional
Number of OMP threads for exact OT solver, by default 1
max_iter : int, optional
Expand Down Expand Up @@ -925,68 +927,62 @@ def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
See :any:`OTResult` for more information.
"""

# Detect backend
arr = [X_s,X_t]
if a is not None:
arr.append(a)
if b is not None:
arr.append(b)
nx = get_backend(*arr)

# create uniform weights if not given
ns, nt = X_s.shape[0], X_t.shape[0]
if a is None:
a = nx.from_numpy(unif(ns), type_as=X_s)
if b is None:
b = nx.from_numpy(unif(nt), type_as=X_s)
if method is not None and method.lower() in ['1d', 'sliced', 'lowrank', 'factored']:
lazy = True

if metric != 'sqeuclidean':
raise (NotImplementedError('Not implemented metric = {} (only sqeulidean)'.format(metric)))
if not lazy: # default non lazy solver calls ot.solve

# compute cost matrix M and use solve function
M = dist(X_a, X_b, metric)

# default values for solutions
potentials = None
lazy_plan = None
res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, n_threads, max_iter, plan_init, potentials_init, tol, verbose)

return res

else:

if max_iter is None:
max_iter = 1000
if tol is None:
tol = 1e-9
if batch_size is None:
batch_size = 100
# Detect backend
nx = get_backend(X_a, X_b, a, b)

# default values for solutions
potentials = None
value = None
value_linear = None
plan = None
lazy_plan = None
status = None

if is_Lazy:
################# WIP ####################
if reg is None or reg == 0: # EMD solver for isLazy ?
if unbalanced is None: # balanced EMD solver for isLazy ?
raise (NotImplementedError('Not implemented balanced with no regularization'))
if reg is None or reg == 0: # EMD solver for isLazy ?

if unbalanced is None: # balanced EMD solver for isLazy ?
raise (NotImplementedError('Exact OT solver with lazy=True not implemented'))

else:
raise (NotImplementedError('Not implemented unbalanced_type="{}" with no regularization'.format(unbalanced_type)))

raise (NotImplementedError('Non regularized solver with unbalanced_type="{}" not implemented'.format(unbalanced_type)))

#############################################
else:

else:
if unbalanced is None:
u, v, log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric='sqeuclidean', numIterMax=max_iter, stopThr=tol,
isLazy=True, batchSize=batch_size, verbose=verbose, log=True)

if max_iter is None:
max_iter = 1000
if tol is None:
tol = 1e-9
if batch_size is None:
batch_size = 100

u, v, log = empirical_sinkhorn(X_a, X_b, reg, a, b, metric=metric, numIterMax=max_iter, stopThr=tol,
isLazy=True, batchSize=batch_size, verbose=verbose, log=True)
# compute potentials
potentials = (log["u"], log["v"])
potentials = (u, v)

# compute lazy_plan
# ...

raise (NotImplementedError('Not implemented balanced with regularization'))

else:
raise (NotImplementedError('Not implemented unbalanced_type="{}" with regularization'.format(unbalanced_type)))

else:
# compute cost matrix M and use solve function
M = dist(X_s, X_t, metric)

res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, n_threads, max_iter, plan_init, potentials_init, tol, verbose)
return res
3 changes: 2 additions & 1 deletion ot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1168,6 +1168,7 @@ def citation(self):
}
"""


class LazyTensor(object):
""" A lazy tensor is a tensor that is not stored in memory. Instead, it is
defined by a function that computes its values on the fly from slices.
Expand Down Expand Up @@ -1232,4 +1233,4 @@ def __getitem__(self, key):
return self._getitem(*k, **self.kwargs)

def __repr__(self):
return "LazyTensor(shape={},attributes=({}))".format(self.shape, ','.join(self.kwargs.keys()))
return "LazyTensor(shape={},attributes=({}))".format(self.shape, ','.join(self.kwargs.keys()))
73 changes: 35 additions & 38 deletions test/test_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,12 +257,9 @@ def test_solve_gromov_not_implemented(nx):
ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type='partial', unbalanced=0.5, symmetric=False)




######## Test functions for ot.solve_sample ########


@pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type", itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type))
def test_solve_sample(nx):
# test solve_sample when is_Lazy = False
n = 100
Expand All @@ -272,8 +269,13 @@ def test_solve_sample(nx):
a = ot.utils.unif(X_s.shape[0])
b = ot.utils.unif(X_t.shape[0])

M = ot.dist(X_s, X_t)

# solve with ot.solve
sol00 = ot.solve(M, a, b)

# solve unif weights
sol0 = ot.solve_sample(X_s, X_t)
sol0 = ot.solve_sample(X_s, X_t)

# solve signe weights
sol = ot.solve_sample(X_s, X_t, a, b)
Expand All @@ -285,6 +287,7 @@ def test_solve_sample(nx):
sol.status

assert_allclose_sol(sol0, sol)
assert_allclose_sol(sol0, sol00)

# solve in backend
X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b)
Expand All @@ -301,47 +304,41 @@ def test_solve_sample(nx):
sol0 = ot.solve_sample(X_s, X_t, reg=1, reg_type='cryptic divergence')


# def test_lazy_solve_sample(nx):
# # test solve_sample when is_Lazy = True
# n = 100
# X_s = np.reshape(1.0 * np.arange(n), (n, 1))
# X_t = np.reshape(1.0 * np.arange(0, n), (n, 1))

def test_lazy_solve_sample(nx):
# test solve_sample when is_Lazy = True
n = 100
X_s = np.reshape(1.0 * np.arange(n), (n, 1))
X_t = np.reshape(1.0 * np.arange(0, n), (n, 1))

a = ot.utils.unif(X_s.shape[0])
b = ot.utils.unif(X_t.shape[0])

# solve unif weights
sol0 = ot.solve_sample(X_s, X_t, reg=0.1, is_Lazy=True) # reg != 0 or None since no implementation yet for is_Lazy=True

# solve signe weights
sol = ot.solve_sample(X_s, X_t, a, b, reg=0.1, is_Lazy=True)

# check some attributes
sol.potentials
sol.lazy_plan
# a = ot.utils.unif(X_s.shape[0])
# b = ot.utils.unif(X_t.shape[0])

assert_allclose_sol(sol0, sol)
# # solve unif weights
# sol0 = ot.solve_sample(X_s, X_t, reg=0.1, lazy=True) # reg != 0 or None since no implementation yet for is_Lazy=True

# solve in backend
X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b)
solb = ot.solve_sample(X_sb, X_tb, ab, bb, reg=0.1, is_Lazy=True)

assert_allclose_sol(sol, solb)
# # solve signe weights
# sol = ot.solve_sample(X_s, X_t, a, b, reg=0.1, lazy=True)

# test not implemented reg==0 (or None) + balanced and check raise
with pytest.raises(NotImplementedError):
sol0 = ot.solve_sample(X_s, X_t, is_Lazy=True) # reg == 0 (or None) + unbalanced= None are default
# # check some attributes
# sol.potentials
# sol.lazy_plan

# test not implemented reg==0 (or None) + unbalanced_type and check raise
with pytest.raises(NotImplementedError):
sol0 = ot.solve_sample(X_s, X_t, unbalanced_type="kl", is_Lazy=True) # reg == 0 (or None) is default

# test not implemented reg != 0 + unbalanced_type and check raise
with pytest.raises(NotImplementedError):
sol0 = ot.solve_sample(X_s, X_t, reg=0.1, unbalanced_type="kl", is_Lazy=True)
# assert_allclose_sol(sol0, sol)

# # solve in backend
# X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b)
# solb = ot.solve_sample(X_sb, X_tb, ab, bb, reg=0.1, lazy=True)

# assert_allclose_sol(sol, solb)

# # test not implemented reg==0 (or None) + balanced and check raise
# with pytest.raises(NotImplementedError):
# sol0 = ot.solve_sample(X_s, X_t, lazy=True) # reg == 0 (or None) + unbalanced= None are default

# # test not implemented reg==0 (or None) + unbalanced_type and check raise
# with pytest.raises(NotImplementedError):
# sol0 = ot.solve_sample(X_s, X_t, unbalanced_type="kl", lazy=True) # reg == 0 (or None) is default

# # test not implemented reg != 0 + unbalanced_type and check raise
# with pytest.raises(NotImplementedError):
# sol0 = ot.solve_sample(X_s, X_t, reg=0.1, unbalanced_type="kl", lazy=True)
1 change: 0 additions & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,4 +569,3 @@ def test_lowrank_LazyTensor(nx):
T = ot.utils.get_lowrank_lazytensor(X1, X2, diag_d, nx=nx)

np.testing.assert_allclose(nx.to_numpy(T[:]), nx.to_numpy(T0))

0 comments on commit 9403851

Please sign in to comment.