Skip to content

Commit

Permalink
better tests corverage
Browse files Browse the repository at this point in the history
  • Loading branch information
rflamary committed Oct 30, 2023
1 parent d548730 commit 7cf5651
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 2 deletions.
2 changes: 1 addition & 1 deletion ot/factored.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def solve_ot(X1, X2, w1, w2):
'vb': logb['v'],
'costa': loga['cost'],
'costb': logb['cost'],
'lazy_plan': get_lowrank_lazytensor(Ga, Gb.T, nx=nx),
'lazy_plan': get_lowrank_lazytensor(Ga * r, Gb.T, nx=nx),
}
return Ga, Gb, X, log_dic

Expand Down
2 changes: 1 addition & 1 deletion ot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ def reduce_lazytensor(a, func, axis=None, nx=None, batch_size=100):
return res

else:
raise (NotImplementedError("Only axis=None is implemented for now."))
raise (NotImplementedError("Only axis=None, 0 or 1 is implemented for now."))


def get_lowrank_lazytensor(Q, R, d=None, nx=None):
Expand Down
1 change: 1 addition & 0 deletions test/test_factored.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def test_factored_ot():
# check constraints
np.testing.assert_allclose(u, Ga.sum(1))
np.testing.assert_allclose(u, Gb.sum(0))
np.testing.assert_allclose(1, log['lazy_plan'][:].sum())


def test_factored_ot_backends(nx):
Expand Down
20 changes: 20 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,9 @@ def getitem(i, j, x1, x2):
# get one column with slices
assert T[::10, 5].shape == (10,)

with pytest.raises(NotImplementedError):
T["error"]


def test_OTResult_LazyTensor(nx):

Expand Down Expand Up @@ -507,6 +510,23 @@ def test_LazyTensor_reduce(nx):
s2 = nx.logsumexp(T[:], axis=1)
np.testing.assert_allclose(nx.to_numpy(s), nx.to_numpy(s2))

# test 3D tensors
def getitem(i, j, k, a, b, c):
return a[i, None, None] * b[None, j, None] * c[None, None, k]

# create a lazy tensor
n = a.shape[0]
T = ot.utils.LazyTensor((n, n, n), getitem, a=a, b=a, c=a)

# total sum
s1 = ot.utils.reduce_lazytensor(T, nx.sum, axis=0, nx=nx)
s2 = ot.utils.reduce_lazytensor(T, nx.sum, axis=1, nx=nx)

np.testing.assert_allclose(nx.to_numpy(s1), nx.to_numpy(s2))

with pytest.raises(NotImplementedError):
ot.utils.reduce_lazytensor(T, nx.sum, axis=2, nx=nx, batch_size=10)


def test_lowrank_LazyTensor(nx):

Expand Down

0 comments on commit 7cf5651

Please sign in to comment.