diff --git a/ot/factored.py b/ot/factored.py index 0d8fcb40d..65613d328 100644 --- a/ot/factored.py +++ b/ot/factored.py @@ -139,7 +139,7 @@ def solve_ot(X1, X2, w1, w2): 'vb': logb['v'], 'costa': loga['cost'], 'costb': logb['cost'], - 'lazy_plan': get_lowrank_lazytensor(Ga, Gb.T, nx=nx), + 'lazy_plan': get_lowrank_lazytensor(Ga * r, Gb.T, nx=nx), } return Ga, Gb, X, log_dic diff --git a/ot/utils.py b/ot/utils.py index 07845cb9e..480d3e964 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -562,7 +562,7 @@ def reduce_lazytensor(a, func, axis=None, nx=None, batch_size=100): return res else: - raise (NotImplementedError("Only axis=None is implemented for now.")) + raise (NotImplementedError("Only axis=None, 0 or 1 is implemented for now.")) def get_lowrank_lazytensor(Q, R, d=None, nx=None): diff --git a/test/test_factored.py b/test/test_factored.py index fd2fd0133..5cfc997ef 100644 --- a/test/test_factored.py +++ b/test/test_factored.py @@ -28,6 +28,7 @@ def test_factored_ot(): # check constraints np.testing.assert_allclose(u, Ga.sum(1)) np.testing.assert_allclose(u, Gb.sum(0)) + np.testing.assert_allclose(1, log['lazy_plan'][:].sum()) def test_factored_ot_backends(nx): diff --git a/test/test_utils.py b/test/test_utils.py index ef146f3a0..3a9d590ab 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -462,6 +462,9 @@ def getitem(i, j, x1, x2): # get one column with slices assert T[::10, 5].shape == (10,) + with pytest.raises(NotImplementedError): + T["error"] + def test_OTResult_LazyTensor(nx): @@ -507,6 +510,23 @@ def test_LazyTensor_reduce(nx): s2 = nx.logsumexp(T[:], axis=1) np.testing.assert_allclose(nx.to_numpy(s), nx.to_numpy(s2)) + # test 3D tensors + def getitem(i, j, k, a, b, c): + return a[i, None, None] * b[None, j, None] * c[None, None, k] + + # create a lazy tensor + n = a.shape[0] + T = ot.utils.LazyTensor((n, n, n), getitem, a=a, b=a, c=a) + + # total sum + s1 = ot.utils.reduce_lazytensor(T, nx.sum, axis=0, nx=nx) + s2 = ot.utils.reduce_lazytensor(T, nx.sum, axis=1, nx=nx) + + np.testing.assert_allclose(nx.to_numpy(s1), nx.to_numpy(s2)) + + with pytest.raises(NotImplementedError): + ot.utils.reduce_lazytensor(T, nx.sum, axis=2, nx=nx, batch_size=10) + def test_lowrank_LazyTensor(nx):