
Commit

add simple low-rank lazytensor creator and add it to log for factored OT
rflamary committed Oct 30, 2023
1 parent 89fdc1f commit d548730
Showing 3 changed files with 73 additions and 4 deletions.
3 changes: 2 additions & 1 deletion ot/factored.py
@@ -7,7 +7,7 @@
 # License: MIT License
 
 from .backend import get_backend
-from .utils import dist
+from .utils import dist, get_lowrank_lazytensor
 from .lp import emd
 from .bregman import sinkhorn

@@ -139,6 +139,7 @@ def solve_ot(X1, X2, w1, w2):
                    'vb': logb['v'],
                    'costa': loga['cost'],
                    'costb': logb['cost'],
+                   'lazy_plan': get_lowrank_lazytensor(Ga, Gb.T, nx=nx),
                    }
         return Ga, Gb, X, log_dic
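
With this change, the log returned by ot.factored.factored_optimal_transport exposes the coupling as a low-rank LazyTensor under the new 'lazy_plan' key. A minimal usage sketch, not from the commit itself, assuming the existing solver API with rank parameter r and log=True:

import numpy as np
import ot

rng = np.random.RandomState(0)
Xa = rng.randn(50, 2)  # source samples
Xb = rng.randn(60, 2)  # target samples

# factored OT through r=4 intermediate points, with the log enabled
Ga, Gb, X, log = ot.factored.factored_optimal_transport(Xa, Xb, r=4, log=True)

T = log['lazy_plan']  # LazyTensor equal to Ga @ Gb, never densified eagerly
print(T.shape)        # (50, 60)
print(T[0].shape)     # one row of the plan, computed on demand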

25 changes: 25 additions & 0 deletions ot/utils.py
@@ -565,6 +565,31 @@ def reduce_lazytensor(a, func, axis=None, nx=None, batch_size=100):
         raise (NotImplementedError("Only axis=None is implemented for now."))
[Codecov / codecov/patch: added line ot/utils.py#L565 was not covered by tests]
 
 
+def get_lowrank_lazytensor(Q, R, d=None, nx=None):
+    """Get a low-rank LazyTensor T=Q@R^T or T=Q@diag(d)@R^T."""
+
+    if nx is None:
+        nx = get_backend(Q, R, d)
+
+    shape = (Q.shape[0], R.shape[0])
+
+    if d is None:
+
+        def func(i, j, Q, R):
+            return nx.dot(Q[i], R[j].T)
+
+        T = LazyTensor(shape, func, Q=Q, R=R)
+
+    else:
+
+        def func(i, j, Q, R, d):
+            return nx.dot(Q[i] * d[None, :], R[j].T)
+
+        T = LazyTensor(shape, func, Q=Q, R=R, d=d)
+
+    return T
+
+
 def get_parameter_pair(parameter):
     r"""Extract a pair of parameters from a given parameter
     Used in unbalanced OT and COOT solvers
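
To make the helper's semantics concrete, a small sketch with the NumPy backend (shapes are illustrative; T equals Q @ R.T, or Q @ diag(d) @ R.T when d is given):

import numpy as np
import ot

Q = np.random.randn(4, 2)  # left factor, shape (n1, r)
R = np.random.randn(5, 2)  # right factor, shape (n2, r)
d = np.random.rand(2)      # optional diagonal weights

T = ot.utils.get_lowrank_lazytensor(Q, R)
np.testing.assert_allclose(T[:], Q @ R.T)         # T[:] densifies the product

Td = ot.utils.get_lowrank_lazytensor(Q, R, d)
np.testing.assert_allclose(Td[:], (Q * d) @ R.T)  # diagonal variant

Only the factors Q, R (and optionally d) are stored, so the full n1 x n2 tensor is never allocated until it is explicitly indexed.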
49 changes: 46 additions & 3 deletions test/test_utils.py
@@ -28,7 +28,7 @@ def getitem(i, j, a, b):
     # create a lazy tensor
     T = ot.utils.LazyTensor((n1, n2), getitem, a=a, b=b)
 
-    return T
+    return T, a, b
 
 
 def test_proj_simplex(nx):
@@ -465,7 +465,7 @@ def getitem(i, j, x1, x2):
 
 def test_OTResult_LazyTensor(nx):
 
-    T = get_LazyTensor(nx)
+    T, a, b = get_LazyTensor(nx)
 
     res = ot.utils.OTResult(lazy_plan=T, batch_size=9, backend=nx)

@@ -475,7 +475,7 @@ def test_OTResult_LazyTensor(nx):
 
 def test_LazyTensor_reduce(nx):
 
-    T = get_LazyTensor(nx)
+    T, a, b = get_LazyTensor(nx)
 
     T0 = T[:]
     s0 = nx.sum(T0)
@@ -506,3 +506,46 @@ def test_LazyTensor_reduce(nx):
     s = ot.utils.reduce_lazytensor(T, nx.logsumexp, axis=1, nx=nx)
     s2 = nx.logsumexp(T[:], axis=1)
     np.testing.assert_allclose(nx.to_numpy(s), nx.to_numpy(s2))
+
+
+def test_lowrank_LazyTensor(nx):
+
+    p = 5
+    n1 = 100
+    n2 = 200
+
+    shape = (n1, n2)
+
+    rng = np.random.RandomState(42)
+    X1 = rng.randn(n1, p)
+    X2 = rng.randn(n2, p)
+    diag_d = rng.rand(p)
+
+    X1, X2, diag_d = nx.from_numpy(X1, X2, diag_d)
+
+    T0 = nx.dot(X1, X2.T)
+
+    T = ot.utils.get_lowrank_lazytensor(X1, X2)
+
+    np.testing.assert_allclose(nx.to_numpy(T[:]), nx.to_numpy(T0))
+
+    assert T.Q is X1
+    assert T.R is X2
+
+    # get the full tensor (not lazy)
+    assert T[:].shape == shape
+
+    # get one component
+    assert T[1, 1] == nx.dot(X1[1], X2[1].T)
+
+    # get one row
+    assert T[1].shape == (n2,)
+
+    # get one column with slices
+    assert T[::10, 5].shape == (10,)
+
+    T0 = nx.dot(X1 * diag_d[None, :], X2.T)
+
+    T = ot.utils.get_lowrank_lazytensor(X1, X2, diag_d, nx=nx)
+
+    np.testing.assert_allclose(nx.to_numpy(T[:]), nx.to_numpy(T0))
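
The test above checks values and shapes; the practical payoff is that reductions over the plan can be streamed in batches instead of allocating the dense n1 x n2 array. A sketch in the spirit of the existing reduce_lazytensor test (NumPy backend; batch_size chosen arbitrarily):

import numpy as np
import ot

nx = ot.backend.NumpyBackend()
rng = np.random.RandomState(42)
X1, X2 = rng.randn(100, 5), rng.randn(200, 5)

T = ot.utils.get_lowrank_lazytensor(X1, X2, nx=nx)

# row sums of the 100 x 200 plan, computed 25 rows at a time
s = ot.utils.reduce_lazytensor(T, nx.sum, axis=1, nx=nx, batch_size=25)
np.testing.assert_allclose(s, (X1 @ X2.T).sum(axis=1))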
