diff --git a/ot/factored.py b/ot/factored.py
index 8d6615876..0d8fcb40d 100644
--- a/ot/factored.py
+++ b/ot/factored.py
@@ -7,7 +7,7 @@
 # License: MIT License
 
 from .backend import get_backend
-from .utils import dist
+from .utils import dist, get_lowrank_lazytensor
 from .lp import emd
 from .bregman import sinkhorn
 
@@ -139,6 +139,7 @@ def solve_ot(X1, X2, w1, w2):
             'vb': logb['v'],
             'costa': loga['cost'],
             'costb': logb['cost'],
+            'lazy_plan': get_lowrank_lazytensor(Ga, Gb.T, nx=nx),
         }
 
         return Ga, Gb, X, log_dic
diff --git a/ot/utils.py b/ot/utils.py
index 08654c98e..07845cb9e 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -565,6 +565,31 @@ def reduce_lazytensor(a, func, axis=None, nx=None, batch_size=100):
         raise (NotImplementedError("Only axis=None is implemented for now."))
 
 
+def get_lowrank_lazytensor(Q, R, d=None, nx=None):
+    """ Get a lowrank LazyTensor T=Q@R^T or T=Q@diag(d)@R^T"""
+
+    if nx is None:
+        nx = get_backend(Q, R, d)
+
+    shape = (Q.shape[0], R.shape[0])
+
+    if d is None:
+
+        def func(i, j, Q, R):
+            return nx.dot(Q[i], R[j].T)
+
+        T = LazyTensor(shape, func, Q=Q, R=R)
+
+    else:
+
+        def func(i, j, Q, R, d):
+            return nx.dot(Q[i] * d[None, :], R[j].T)
+
+        T = LazyTensor(shape, func, Q=Q, R=R, d=d)
+
+    return T
+
+
 def get_parameter_pair(parameter):
     r"""Extract a pair of parameters from a given parameter
     Used in unbalanced OT and COOT solvers
diff --git a/test/test_utils.py b/test/test_utils.py
index 96249753f..ef146f3a0 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -28,7 +28,7 @@ def getitem(i, j, a, b):
     # create a lazy tensor
     T = ot.utils.LazyTensor((n1, n2), getitem, a=a, b=b)
 
-    return T
+    return T, a, b
 
 
 def test_proj_simplex(nx):
@@ -465,7 +465,7 @@ def getitem(i, j, x1, x2):
 
 
 def test_OTResult_LazyTensor(nx):
-    T = get_LazyTensor(nx)
+    T, a, b = get_LazyTensor(nx)
 
     res = ot.utils.OTResult(lazy_plan=T, batch_size=9, backend=nx)
 
@@ -475,7 +475,7 @@ def test_OTResult_LazyTensor(nx):
 
 
 def test_LazyTensor_reduce(nx):
-    T = get_LazyTensor(nx)
+    T, a, b = get_LazyTensor(nx)
 
     T0 = T[:]
     s0 = nx.sum(T0)
@@ -506,3 +506,46 @@ def test_LazyTensor_reduce(nx):
     s = ot.utils.reduce_lazytensor(T, nx.logsumexp, axis=1, nx=nx)
     s2 = nx.logsumexp(T[:], axis=1)
     np.testing.assert_allclose(nx.to_numpy(s), nx.to_numpy(s2))
+
+
+def test_lowrank_LazyTensor(nx):
+
+    p = 5
+    n1 = 100
+    n2 = 200
+
+    shape = (n1, n2)
+
+    rng = np.random.RandomState(42)
+    X1 = rng.randn(n1, p)
+    X2 = rng.randn(n2, p)
+    diag_d = rng.rand(p)
+
+    X1, X2, diag_d = nx.from_numpy(X1, X2, diag_d)
+
+    T0 = nx.dot(X1, X2.T)
+
+    T = ot.utils.get_lowrank_lazytensor(X1, X2)
+
+    np.testing.assert_allclose(nx.to_numpy(T[:]), nx.to_numpy(T0))
+
+    assert T.Q is X1
+    assert T.R is X2
+
+    # get the full tensor (not lazy)
+    assert T[:].shape == shape
+
+    # get one component
+    assert T[1, 1] == nx.dot(X1[1], X2[1].T)
+
+    # get one row
+    assert T[1].shape == (n2,)
+
+    # get one column with slices
+    assert T[::10, 5].shape == (10,)
+
+    T0 = nx.dot(X1 * diag_d[None, :], X2.T)
+
+    T = ot.utils.get_lowrank_lazytensor(X1, X2, diag_d, nx=nx)
+
+    np.testing.assert_allclose(nx.to_numpy(T[:]), nx.to_numpy(T0))
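
Usage sketch (not part of the patch): a minimal illustration of the new `get_lowrank_lazytensor` helper with the NumPy backend. The variable names `Q`, `R`, and `d` mirror the patched signature; the indexing behavior shown is the one exercised by `test_lowrank_LazyTensor` above.

```python
import numpy as np
import ot

rng = np.random.RandomState(0)
Q = rng.randn(1000, 5)  # left factor, shape (n1, p)
R = rng.randn(2000, 5)  # right factor, shape (n2, p)
d = rng.rand(5)         # optional diagonal weights

# T stands for Q @ diag(d) @ R.T; no (1000, 2000) array is ever stored
T = ot.utils.get_lowrank_lazytensor(Q, R, d)

row = T[0]           # one row, computed on the fly, shape (2000,)
block = T[:10, :10]  # a small dense block, shape (10, 10)

# sanity check against the dense product
np.testing.assert_allclose(T[:], (Q * d[None, :]) @ R.T)
```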
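On the `ot/factored.py` side, a sketch of how the new `'lazy_plan'` log entry can be consumed. This assumes the factored solver is exposed as `ot.factored_optimal_transport` and that `log=True` returns the log dict this patch extends; both are assumptions about the surrounding library, not shown in the diff.

```python
import ot
import numpy as np

rng = np.random.RandomState(0)
X1 = rng.randn(50, 2)
X2 = rng.randn(40, 2)

# factored OT through r=10 anchor points; log=True returns the log dict
Ga, Gb, X, log = ot.factored_optimal_transport(X1, X2, r=10, log=True)

T = log['lazy_plan']  # low-rank LazyTensor for the (50, 40) plan Ga @ Gb
row = T[0]            # one row of the plan, computed on demand

# batched reductions reuse the existing helper, as in the tests above
nx = ot.backend.get_backend(X1, X2)
marg = ot.utils.reduce_lazytensor(T, nx.sum, axis=1, nx=nx)
```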