From 6d0a888b7804d7cbce0845a384ab37c458086fcc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= <fmunoz@ug.uchile.cl> Date: Fri, 27 Dec 2024 11:12:48 -0300 Subject: [PATCH 01/11] feat: add _get_convol_img_fn --- ot/bregman/_convolutional.py | 97 +++++++++++++++--------------------- 1 file changed, 40 insertions(+), 57 deletions(-) diff --git a/ot/bregman/_convolutional.py b/ot/bregman/_convolutional.py index 0e6548710..9a9895d82 100644 --- a/ot/bregman/_convolutional.py +++ b/ot/bregman/_convolutional.py @@ -10,8 +10,36 @@ import warnings -from ..utils import list_to_array from ..backend import get_backend +from ..utils import list_to_array + + +def _get_convol_img_fn(nx, width, height, reg, type_as, log_domain=False): + """ + Return the convolution operator for 2D images. The function constructed is equivalent to blurring on horizontal then vertical directions. + """ + t1 = nx.linspace(0, 1, width, type_as=type_as) + Y1, X1 = nx.meshgrid(t1, t1) + M1 = -((X1 - Y1) ** 2) / reg + + t2 = nx.linspace(0, 1, height, type_as=type_as) + Y2, X2 = nx.meshgrid(t2, t2) + M2 = -((X2 - Y2) ** 2) / reg + + def convol_imgs(log_imgs): + log_imgs = nx.logsumexp(M1[:, :, None] + log_imgs[None], axis=1) + log_imgs = nx.logsumexp(M2[:, :, None] + log_imgs.T[None], axis=1).T + return log_imgs + + if not log_domain: + K1, K2 = nx.exp(M1), nx.exp(M2) + + def convol_imgs(imgs): + kx = nx.einsum("...ij,kjl->kil", K1, imgs) + kxy = nx.einsum("...ij,klj->kli", K2, kx) + return kxy + + return convol_imgs def convolutional_barycenter2d( @@ -133,37 +161,26 @@ def _convolutional_barycenter2d( """ A = list_to_array(A) + n_hists, width, height = A.shape nx = get_backend(A) if weights is None: - weights = nx.ones((A.shape[0],), type_as=A) / A.shape[0] + weights = nx.ones((n_hists,), type_as=A) / n_hists else: - assert len(weights) == A.shape[0] + assert len(weights) == n_hists if log: log = {"err": []} - bar = nx.ones(A.shape[1:], type_as=A) + bar = nx.ones((width, height), type_as=A) bar /= nx.sum(bar) U = nx.ones(A.shape, type_as=A) V = nx.ones(A.shape, type_as=A) err = 1 # build the convolution operator - # this is equivalent to blurring on horizontal then vertical directions - t = nx.linspace(0, 1, A.shape[1], type_as=A) - [Y, X] = nx.meshgrid(t, t) - K1 = nx.exp(-((X - Y) ** 2) / reg) - - t = nx.linspace(0, 1, A.shape[2], type_as=A) - [Y, X] = nx.meshgrid(t, t) - K2 = nx.exp(-((X - Y) ** 2) / reg) - - def convol_imgs(imgs): - kx = nx.einsum("...ij,kjl->kil", K1, imgs) - kxy = nx.einsum("...ij,klj->kli", K2, kx) - return kxy + convol_imgs = _get_convol_img_fn(nx, width, height, reg, type_as=A) KU = convol_imgs(U) for ii in range(numItermax): @@ -195,6 +212,7 @@ def convol_imgs(imgs): if log: log["niter"] = ii log["U"] = U + log["V"] = V return bar, log else: return bar @@ -236,19 +254,7 @@ def _convolutional_barycenter2d_log( err = 1 # build the convolution operator - # this is equivalent to blurring on horizontal then vertical directions - t = nx.linspace(0, 1, width, type_as=A) - [Y, X] = nx.meshgrid(t, t) - M1 = -((X - Y) ** 2) / reg - - t = nx.linspace(0, 1, height, type_as=A) - [Y, X] = nx.meshgrid(t, t) - M2 = -((X - Y) ** 2) / reg - - def convol_img(log_img): - log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1) - log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T - return log_img + convol_img = _get_convol_img_fn(nx, width, height, reg, type_as=A, log_domain=True) logA = nx.log(A + stabThr) log_KU, G, F = nx.zeros((3, *logA.shape), type_as=A) @@ -417,23 +423,11 @@ def _convolutional_barycenter2d_debiased( bar /= width * height U = nx.ones(A.shape, type_as=A) V = nx.ones(A.shape, type_as=A) - c = nx.ones(A.shape[1:], type_as=A) + c = nx.ones((width, height), type_as=A) err = 1 # build the convolution operator - # this is equivalent to blurring on horizontal then vertical directions - t = nx.linspace(0, 1, width, type_as=A) - [Y, X] = nx.meshgrid(t, t) - K1 = nx.exp(-((X - Y) ** 2) / reg) - - t = nx.linspace(0, 1, height, type_as=A) - [Y, X] = nx.meshgrid(t, t) - K2 = nx.exp(-((X - Y) ** 2) / reg) - - def convol_imgs(imgs): - kx = nx.einsum("...ij,kjl->kil", K1, imgs) - kxy = nx.einsum("...ij,klj->kli", K2, kx) - return kxy + convol_imgs = _get_convol_img_fn(nx, width, height, reg, type_as=A) KU = convol_imgs(U) for ii in range(numItermax): @@ -471,6 +465,7 @@ def convol_imgs(imgs): if log: log["niter"] = ii log["U"] = U + log["V"] = V return bar, log else: return bar @@ -507,19 +502,7 @@ def _convolutional_barycenter2d_debiased_log( err = 1 # build the convolution operator - # this is equivalent to blurring on horizontal then vertical directions - t = nx.linspace(0, 1, width, type_as=A) - [Y, X] = nx.meshgrid(t, t) - M1 = -((X - Y) ** 2) / reg - - t = nx.linspace(0, 1, height, type_as=A) - [Y, X] = nx.meshgrid(t, t) - M2 = -((X - Y) ** 2) / reg - - def convol_img(log_img): - log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1) - log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T - return log_img + convol_img = _get_convol_img_fn(nx, width, height, reg, type_as=A, log_domain=True) logA = nx.log(A + stabThr) log_bar, c = nx.zeros((2, width, height), type_as=A) From 698321eb2973a3c5490dae3d59751af70acb97d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= <fmunoz@ug.uchile.cl> Date: Fri, 27 Dec 2024 11:17:22 -0300 Subject: [PATCH 02/11] refactor: add warning msg --- ot/bregman/_convolutional.py | 34 +++++++++++----------------------- 1 file changed, 11 insertions(+), 23 deletions(-) diff --git a/ot/bregman/_convolutional.py b/ot/bregman/_convolutional.py index 9a9895d82..a3b999f22 100644 --- a/ot/bregman/_convolutional.py +++ b/ot/bregman/_convolutional.py @@ -13,11 +13,15 @@ from ..backend import get_backend from ..utils import list_to_array +_warning_msg = ( + "Convolutional Sinkhorn did not converge. " + "Try a larger number of iterations `numItermax` " + "or a larger entropy `reg`." +) + def _get_convol_img_fn(nx, width, height, reg, type_as, log_domain=False): - """ - Return the convolution operator for 2D images. The function constructed is equivalent to blurring on horizontal then vertical directions. - """ + """Return the convolution operator for 2D images. The function constructed is equivalent to blurring on horizontal then vertical directions.""" t1 = nx.linspace(0, 1, width, type_as=type_as) Y1, X1 = nx.meshgrid(t1, t1) M1 = -((X1 - Y1) ** 2) / reg @@ -204,11 +208,7 @@ def _convolutional_barycenter2d( else: if warn: - warnings.warn( - "Convolutional Sinkhorn did not converge. " - "Try a larger number of iterations `numItermax` " - "or a larger entropy `reg`." - ) + warnings.warn(_warning_msg) if log: log["niter"] = ii log["U"] = U @@ -282,11 +282,7 @@ def _convolutional_barycenter2d_log( else: if warn: - warnings.warn( - "Convolutional Sinkhorn did not converge. " - "Try a larger number of iterations `numItermax` " - "or a larger entropy `reg`." - ) + warnings.warn(_warning_msg) if log: log["niter"] = ii return nx.exp(log_bar), log @@ -457,11 +453,7 @@ def _convolutional_barycenter2d_debiased( break else: if warn: - warnings.warn( - "Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`." - ) + warnings.warn(_warning_msg) if log: log["niter"] = ii log["U"] = U @@ -534,11 +526,7 @@ def _convolutional_barycenter2d_debiased_log( else: if warn: - warnings.warn( - "Convolutional Sinkhorn did not converge. " - "Try a larger number of iterations `numItermax` " - "or a larger entropy `reg`." - ) + warnings.warn(_warning_msg) if log: log["niter"] = ii return nx.exp(log_bar), log From 2dd4610e9ef00725804143645a1630f6a7cea4cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= <fmunoz@ug.uchile.cl> Date: Fri, 27 Dec 2024 11:21:30 -0300 Subject: [PATCH 03/11] refactor: encapsulate the report printing in a function --- ot/bregman/_convolutional.py | 27 +++++++++++---------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/ot/bregman/_convolutional.py b/ot/bregman/_convolutional.py index a3b999f22..11a1c5903 100644 --- a/ot/bregman/_convolutional.py +++ b/ot/bregman/_convolutional.py @@ -46,6 +46,13 @@ def convol_imgs(imgs): return convol_imgs +def _print_report(ii, err): + """Print the report of the iteration.""" + if ii % 200 == 0: + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(ii, err)) + + def convolutional_barycenter2d( A, reg, @@ -198,11 +205,8 @@ def _convolutional_barycenter2d( # log and verbose print if log: log["err"].append(err) - if verbose: - if ii % 200 == 0: - print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) - print("{:5d}|{:8e}|".format(ii, err)) + _print_report(ii, err) if err < stopThr: break @@ -271,11 +275,8 @@ def _convolutional_barycenter2d_log( # log and verbose print if log: log["err"].append(err) - if verbose: - if ii % 200 == 0: - print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) - print("{:5d}|{:8e}|".format(ii, err)) + _print_report(ii, err) if err < stopThr: break G = log_bar[None, :, :] - log_KU @@ -441,11 +442,8 @@ def _convolutional_barycenter2d_debiased( # log and verbose print if log: log["err"].append(err) - if verbose: - if ii % 200 == 0: - print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) - print("{:5d}|{:8e}|".format(ii, err)) + _print_report(ii, err) # debiased Sinkhorn does not converge monotonically # guarantee a few iterations are done before stopping @@ -515,11 +513,8 @@ def _convolutional_barycenter2d_debiased_log( # log and verbose print if log: log["err"].append(err) - if verbose: - if ii % 200 == 0: - print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) - print("{:5d}|{:8e}|".format(ii, err)) + _print_report(ii, err) if err < stopThr and ii > 20: break G = log_bar[None, :, :] - log_KU From 138bacb37091a142d40747bb26cc738d8ebaafe1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= <fmunoz@ug.uchile.cl> Date: Fri, 27 Dec 2024 12:16:46 -0300 Subject: [PATCH 04/11] docs: add some documentation in the function _get_convol_img_fn --- ot/bregman/_convolutional.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ot/bregman/_convolutional.py b/ot/bregman/_convolutional.py index 11a1c5903..d4ad92466 100644 --- a/ot/bregman/_convolutional.py +++ b/ot/bregman/_convolutional.py @@ -21,7 +21,9 @@ def _get_convol_img_fn(nx, width, height, reg, type_as, log_domain=False): - """Return the convolution operator for 2D images. The function constructed is equivalent to blurring on horizontal then vertical directions.""" + """Return the convolution operator for 2D images. + + The function constructed is equivalent to blurring on horizontal then vertical directions.""" t1 = nx.linspace(0, 1, width, type_as=type_as) Y1, X1 = nx.meshgrid(t1, t1) M1 = -((X1 - Y1) ** 2) / reg @@ -30,11 +32,13 @@ def _get_convol_img_fn(nx, width, height, reg, type_as, log_domain=False): Y2, X2 = nx.meshgrid(t2, t2) M2 = -((X2 - Y2) ** 2) / reg + # As M1 and M2 are computed first, we can use them to compute the convolution in log-domain def convol_imgs(log_imgs): log_imgs = nx.logsumexp(M1[:, :, None] + log_imgs[None], axis=1) log_imgs = nx.logsumexp(M2[:, :, None] + log_imgs.T[None], axis=1).T return log_imgs + # If normal domain is selected, we can use M1 and M2 to compute the convolution if not log_domain: K1, K2 = nx.exp(M1), nx.exp(M2) From dadb470bf861beb9f5269bf178ce4daf4b0e2b7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= <fmunoz@ug.uchile.cl> Date: Fri, 27 Dec 2024 12:19:29 -0300 Subject: [PATCH 05/11] docs: add realise --- RELEASES.md | 1 + 1 file changed, 1 insertion(+) diff --git a/RELEASES.md b/RELEASES.md index f3f100b66..0ddac599b 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -40,6 +40,7 @@ This release also contains few bug fixes, concerning the support of any metric i - Notes before depreciating partial Gromov-Wasserstein function in `ot.partial` moved to ot.gromov (PR #663) - Create `ot.gromov._partial` add new features `loss_fun = "kl_loss"` and `symmetry=False` to all solvers while increasing speed + updating adequatly `ot.solvers` (PR #663) - Added `ot.unbalanced.sinkhorn_unbalanced_translation_invariant` (PR #676) +- Refactored `ot.bregman._convolutional` to improve readability (PR #709) #### Closed issues - Fixed `ot.gaussian` ignoring weights when computing means (PR #649, Issue #648) From 1030815e1b401336de2241211f71df386eb37a5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= <femunoz@dim.uchile.cl> Date: Thu, 2 Jan 2025 13:34:48 -0300 Subject: [PATCH 06/11] refactor: change function _get_convol_img_fn for more clarity --- ot/bregman/_convolutional.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/ot/bregman/_convolutional.py b/ot/bregman/_convolutional.py index d4ad92466..c8793cb25 100644 --- a/ot/bregman/_convolutional.py +++ b/ot/bregman/_convolutional.py @@ -31,13 +31,7 @@ def _get_convol_img_fn(nx, width, height, reg, type_as, log_domain=False): t2 = nx.linspace(0, 1, height, type_as=type_as) Y2, X2 = nx.meshgrid(t2, t2) M2 = -((X2 - Y2) ** 2) / reg - - # As M1 and M2 are computed first, we can use them to compute the convolution in log-domain - def convol_imgs(log_imgs): - log_imgs = nx.logsumexp(M1[:, :, None] + log_imgs[None], axis=1) - log_imgs = nx.logsumexp(M2[:, :, None] + log_imgs.T[None], axis=1).T - return log_imgs - + # If normal domain is selected, we can use M1 and M2 to compute the convolution if not log_domain: K1, K2 = nx.exp(M1), nx.exp(M2) @@ -47,6 +41,13 @@ def convol_imgs(imgs): kxy = nx.einsum("...ij,klj->kli", K2, kx) return kxy + # Else, we can use M1 and M2 to compute the convolution in log-domain + else: + def convol_imgs(log_imgs): + log_imgs = nx.logsumexp(M1[:, :, None] + log_imgs[None], axis=1) + log_imgs = nx.logsumexp(M2[:, :, None] + log_imgs.T[None], axis=1).T + return log_imgs + return convol_imgs From a70cf8dc5f2d25184b4d0d6b8198b3f9b2722020 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= <fmunoz@ug.uchile.cl> Date: Fri, 3 Jan 2025 09:02:14 -0300 Subject: [PATCH 07/11] refactor: run pre-commit --- ot/bregman/_convolutional.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ot/bregman/_convolutional.py b/ot/bregman/_convolutional.py index c8793cb25..412cca9e5 100644 --- a/ot/bregman/_convolutional.py +++ b/ot/bregman/_convolutional.py @@ -31,7 +31,7 @@ def _get_convol_img_fn(nx, width, height, reg, type_as, log_domain=False): t2 = nx.linspace(0, 1, height, type_as=type_as) Y2, X2 = nx.meshgrid(t2, t2) M2 = -((X2 - Y2) ** 2) / reg - + # If normal domain is selected, we can use M1 and M2 to compute the convolution if not log_domain: K1, K2 = nx.exp(M1), nx.exp(M2) @@ -43,6 +43,7 @@ def convol_imgs(imgs): # Else, we can use M1 and M2 to compute the convolution in log-domain else: + def convol_imgs(log_imgs): log_imgs = nx.logsumexp(M1[:, :, None] + log_imgs[None], axis=1) log_imgs = nx.logsumexp(M2[:, :, None] + log_imgs.T[None], axis=1).T From 80d954255c76d2bd240b0e403b64fd2c1a384fc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= <fmunoz@ug.uchile.cl> Date: Wed, 8 Jan 2025 20:10:38 -0300 Subject: [PATCH 08/11] test: refactor tests to delete the error for unavailable backends --- test/test_bregman.py | 236 ++++++++++++++++++------------------------- 1 file changed, 99 insertions(+), 137 deletions(-) diff --git a/test/test_bregman.py b/test/test_bregman.py index 6c0c0e8f2..d3dbb5b11 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -825,22 +825,18 @@ def test_wasserstein_bary_2d(nx, method): # wasserstein reg = 1e-2 - if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": - with pytest.raises(NotImplementedError): - ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method) - else: - bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d( - A, reg, method=method, verbose=True, log=True - ) - bary_wass = nx.to_numpy( - ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method) - ) + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d( + A, reg, method=method, verbose=True, log=True + ) + bary_wass = nx.to_numpy( + ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method) + ) - np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) - np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) - # help in checking if log and verbose do not bug the function - ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) @pytest.skip_backend("tf") @@ -856,27 +852,23 @@ def test_wasserstein_bary_2d_dtype_device(nx, method): # wasserstein reg = 1e-2 - if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": - with pytest.raises(NotImplementedError): - ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) - else: - # Compute the barycenter with numpy - bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d( - A, reg, method=method, verbose=True, log=True - ) - # Compute the barycenter with the backend - bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) - # Convert the backend result to numpy, to compare with the numpy result - bary_wass = nx.to_numpy(bary_wass_b) + # Compute the barycenter with numpy + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d( + A, reg, method=method, verbose=True, log=True + ) + # Compute the barycenter with the backend + bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) + # Convert the backend result to numpy, to compare with the numpy result + bary_wass = nx.to_numpy(bary_wass_b) - np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) - np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) - # help in checking if log and verbose do not bug the function - ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) - # Test that the dtype and device are the same after the computation - nx.assert_same_dtype_device(Ab, bary_wass_b) + # Test that the dtype and device are the same after the computation + nx.assert_same_dtype_device(Ab, bary_wass_b) @pytest.mark.skipif(not tf, reason="tf not installed") @@ -894,37 +886,6 @@ def test_wasserstein_bary_2d_device_tf(method): # wasserstein reg = 1e-2 - if method == "sinkhorn_log": - with pytest.raises(NotImplementedError): - ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) - else: - # Compute the barycenter with numpy - bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d( - A, reg, method=method, verbose=True, log=True - ) - # Compute the barycenter with the backend - bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) - # Convert the backend result to numpy, to compare with the numpy result - bary_wass = nx.to_numpy(bary_wass_b) - - np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) - np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) - - # help in checking if log and verbose do not bug the function - ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) - - # Test that the dtype and device are the same after the computation - nx.assert_same_dtype_device(Ab, bary_wass_b) - - # Check that everything happens on the GPU - Ab = nx.from_numpy(A) - - # wasserstein - reg = 1e-2 - if method == "sinkhorn_log": - with pytest.raises(NotImplementedError): - ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) - else: # Compute the barycenter with numpy bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d( A, reg, method=method, verbose=True, log=True @@ -943,9 +904,32 @@ def test_wasserstein_bary_2d_device_tf(method): # Test that the dtype and device are the same after the computation nx.assert_same_dtype_device(Ab, bary_wass_b) - # Check this only if GPU is available - if len(tf.config.list_physical_devices("GPU")) > 0: - assert nx.dtype_device(bary_wass_b)[1].startswith("GPU") + # Check that everything happens on the GPU + Ab = nx.from_numpy(A) + + # wasserstein + reg = 1e-2 + # Compute the barycenter with numpy + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d( + A, reg, method=method, verbose=True, log=True + ) + # Compute the barycenter with the backend + bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) + # Convert the backend result to numpy, to compare with the numpy result + bary_wass = nx.to_numpy(bary_wass_b) + + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) + + # Test that the dtype and device are the same after the computation + nx.assert_same_dtype_device(Ab, bary_wass_b) + + # Check this only if GPU is available + if len(tf.config.list_physical_devices("GPU")) > 0: + assert nx.dtype_device(bary_wass_b)[1].startswith("GPU") @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) @@ -957,22 +941,18 @@ def test_wasserstein_bary_2d_debiased(nx, method): # wasserstein reg = 1e-2 - if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": - with pytest.raises(NotImplementedError): - ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method) - else: - bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased( - A, reg, method=method, verbose=True, log=True - ) - bary_wass = nx.to_numpy( - ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method) - ) + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased( + A, reg, method=method, verbose=True, log=True + ) + bary_wass = nx.to_numpy( + ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method) + ) - np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) - np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) - # help in checking if log and verbose do not bug the function - ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True) + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True) @pytest.skip_backend("tf") @@ -988,31 +968,25 @@ def test_wasserstein_bary_2d_debiased_dtype_device(nx, method): # wasserstein reg = 1e-2 - if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": - with pytest.raises(NotImplementedError): - ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method) - else: - # Compute the barycenter with numpy - bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased( - A, reg, method=method, verbose=True, log=True - ) - # Compute the barycenter with the backend - bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased( - Ab, reg, method=method - ) - # Convert the backend result to numpy, to compare with the numpy result - bary_wass = nx.to_numpy(bary_wass_b) + # Compute the barycenter with numpy + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased( + A, reg, method=method, verbose=True, log=True + ) + # Compute the barycenter with the backend + bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased( + Ab, reg, method=method + ) + # Convert the backend result to numpy, to compare with the numpy result + bary_wass = nx.to_numpy(bary_wass_b) - np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) - np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) - # help in checking if log and verbose do not bug the function - ot.bregman.convolutional_barycenter2d_debiased( - A, reg, log=True, verbose=True - ) + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True) - # Test that the dtype and device are the same after the computation - nx.assert_same_dtype_device(Ab, bary_wass_b) + # Test that the dtype and device are the same after the computation + nx.assert_same_dtype_device(Ab, bary_wass_b) @pytest.mark.skipif(not tf, reason="tf not installed") @@ -1030,41 +1004,6 @@ def test_wasserstein_bary_2d_debiased_device_tf(method): # wasserstein reg = 1e-2 - if method == "sinkhorn_log": - with pytest.raises(NotImplementedError): - ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method) - else: - # Compute the barycenter with numpy - bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased( - A, reg, method=method, verbose=True, log=True - ) - # Compute the barycenter with the backend - bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased( - Ab, reg, method=method - ) - # Convert the backend result to numpy, to compare with the numpy result - bary_wass = nx.to_numpy(bary_wass_b) - - np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) - np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) - - # help in checking if log and verbose do not bug the function - ot.bregman.convolutional_barycenter2d_debiased( - A, reg, log=True, verbose=True - ) - - # Test that the dtype and device are the same after the computation - nx.assert_same_dtype_device(Ab, bary_wass_b) - - # Check that everything happens on the GPU - Ab = nx.from_numpy(A) - - # wasserstein - reg = 1e-2 - if method == "sinkhorn_log": - with pytest.raises(NotImplementedError): - ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method) - else: # Compute the barycenter with numpy bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased( A, reg, method=method, verbose=True, log=True @@ -1077,6 +1016,29 @@ def test_wasserstein_bary_2d_debiased_device_tf(method): bary_wass = nx.to_numpy(bary_wass_b) np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True) + + # Test that the dtype and device are the same after the computation + nx.assert_same_dtype_device(Ab, bary_wass_b) + + # Check that everything happens on the GPU + Ab = nx.from_numpy(A) + + # wasserstein + reg = 1e-2 + # Compute the barycenter with numpy + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased( + A, reg, method=method, verbose=True, log=True + ) + # Compute the barycenter with the backend + bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method) + # Convert the backend result to numpy, to compare with the numpy result + bary_wass = nx.to_numpy(bary_wass_b) + + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) def test_unmix(nx): From e8d4de5c105df7dcc79c3ac9a9e3784b46ce3ac9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= <fmunoz@ug.uchile.cl> Date: Wed, 8 Jan 2025 20:12:42 -0300 Subject: [PATCH 09/11] feat: delete not implemented error in convolutional module --- ot/bregman/_convolutional.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/ot/bregman/_convolutional.py b/ot/bregman/_convolutional.py index 412cca9e5..6f0de8330 100644 --- a/ot/bregman/_convolutional.py +++ b/ot/bregman/_convolutional.py @@ -246,11 +246,6 @@ def _convolutional_barycenter2d_log( A = list_to_array(A) nx = get_backend(A) - if nx.__name__ in ("jax", "tf"): - raise NotImplementedError( - "Log-domain functions are not yet implemented" - " for Jax and TF. Use numpy or torch arrays instead." - ) n_hists, width, height = A.shape @@ -483,11 +478,7 @@ def _convolutional_barycenter2d_debiased_log( A = list_to_array(A) n_hists, width, height = A.shape nx = get_backend(A) - if nx.__name__ in ("jax", "tf"): - raise NotImplementedError( - "Log-domain functions are not yet implemented" - " for Jax and TF. Use numpy or torch arrays instead." - ) + if weights is None: weights = nx.ones((n_hists,), type_as=A) / n_hists else: From b0c5ef335a4aa8e9063f34cb39acc52a028f27d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= <fmunoz@ug.uchile.cl> Date: Wed, 8 Jan 2025 22:21:21 -0300 Subject: [PATCH 10/11] revert the last two commits --- ot/bregman/_convolutional.py | 11 +- test/test_bregman.py | 236 ++++++++++++++++++++--------------- 2 files changed, 147 insertions(+), 100 deletions(-) diff --git a/ot/bregman/_convolutional.py b/ot/bregman/_convolutional.py index 6f0de8330..412cca9e5 100644 --- a/ot/bregman/_convolutional.py +++ b/ot/bregman/_convolutional.py @@ -246,6 +246,11 @@ def _convolutional_barycenter2d_log( A = list_to_array(A) nx = get_backend(A) + if nx.__name__ in ("jax", "tf"): + raise NotImplementedError( + "Log-domain functions are not yet implemented" + " for Jax and TF. Use numpy or torch arrays instead." + ) n_hists, width, height = A.shape @@ -478,7 +483,11 @@ def _convolutional_barycenter2d_debiased_log( A = list_to_array(A) n_hists, width, height = A.shape nx = get_backend(A) - + if nx.__name__ in ("jax", "tf"): + raise NotImplementedError( + "Log-domain functions are not yet implemented" + " for Jax and TF. Use numpy or torch arrays instead." + ) if weights is None: weights = nx.ones((n_hists,), type_as=A) / n_hists else: diff --git a/test/test_bregman.py b/test/test_bregman.py index d3dbb5b11..6c0c0e8f2 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -825,18 +825,22 @@ def test_wasserstein_bary_2d(nx, method): # wasserstein reg = 1e-2 - bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d( - A, reg, method=method, verbose=True, log=True - ) - bary_wass = nx.to_numpy( - ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method) - ) + if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method) + else: + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d( + A, reg, method=method, verbose=True, log=True + ) + bary_wass = nx.to_numpy( + ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method) + ) - np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) - np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) - # help in checking if log and verbose do not bug the function - ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) @pytest.skip_backend("tf") @@ -852,23 +856,27 @@ def test_wasserstein_bary_2d_dtype_device(nx, method): # wasserstein reg = 1e-2 - # Compute the barycenter with numpy - bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d( - A, reg, method=method, verbose=True, log=True - ) - # Compute the barycenter with the backend - bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) - # Convert the backend result to numpy, to compare with the numpy result - bary_wass = nx.to_numpy(bary_wass_b) + if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) + else: + # Compute the barycenter with numpy + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d( + A, reg, method=method, verbose=True, log=True + ) + # Compute the barycenter with the backend + bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) + # Convert the backend result to numpy, to compare with the numpy result + bary_wass = nx.to_numpy(bary_wass_b) - np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) - np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) - # help in checking if log and verbose do not bug the function - ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) - # Test that the dtype and device are the same after the computation - nx.assert_same_dtype_device(Ab, bary_wass_b) + # Test that the dtype and device are the same after the computation + nx.assert_same_dtype_device(Ab, bary_wass_b) @pytest.mark.skipif(not tf, reason="tf not installed") @@ -886,6 +894,37 @@ def test_wasserstein_bary_2d_device_tf(method): # wasserstein reg = 1e-2 + if method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) + else: + # Compute the barycenter with numpy + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d( + A, reg, method=method, verbose=True, log=True + ) + # Compute the barycenter with the backend + bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) + # Convert the backend result to numpy, to compare with the numpy result + bary_wass = nx.to_numpy(bary_wass_b) + + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) + + # Test that the dtype and device are the same after the computation + nx.assert_same_dtype_device(Ab, bary_wass_b) + + # Check that everything happens on the GPU + Ab = nx.from_numpy(A) + + # wasserstein + reg = 1e-2 + if method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) + else: # Compute the barycenter with numpy bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d( A, reg, method=method, verbose=True, log=True @@ -904,32 +943,9 @@ def test_wasserstein_bary_2d_device_tf(method): # Test that the dtype and device are the same after the computation nx.assert_same_dtype_device(Ab, bary_wass_b) - # Check that everything happens on the GPU - Ab = nx.from_numpy(A) - - # wasserstein - reg = 1e-2 - # Compute the barycenter with numpy - bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d( - A, reg, method=method, verbose=True, log=True - ) - # Compute the barycenter with the backend - bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) - # Convert the backend result to numpy, to compare with the numpy result - bary_wass = nx.to_numpy(bary_wass_b) - - np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) - np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) - - # help in checking if log and verbose do not bug the function - ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) - - # Test that the dtype and device are the same after the computation - nx.assert_same_dtype_device(Ab, bary_wass_b) - - # Check this only if GPU is available - if len(tf.config.list_physical_devices("GPU")) > 0: - assert nx.dtype_device(bary_wass_b)[1].startswith("GPU") + # Check this only if GPU is available + if len(tf.config.list_physical_devices("GPU")) > 0: + assert nx.dtype_device(bary_wass_b)[1].startswith("GPU") @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) @@ -941,18 +957,22 @@ def test_wasserstein_bary_2d_debiased(nx, method): # wasserstein reg = 1e-2 - bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased( - A, reg, method=method, verbose=True, log=True - ) - bary_wass = nx.to_numpy( - ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method) - ) + if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method) + else: + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased( + A, reg, method=method, verbose=True, log=True + ) + bary_wass = nx.to_numpy( + ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method) + ) - np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) - np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) - # help in checking if log and verbose do not bug the function - ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True) + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True) @pytest.skip_backend("tf") @@ -968,25 +988,31 @@ def test_wasserstein_bary_2d_debiased_dtype_device(nx, method): # wasserstein reg = 1e-2 - # Compute the barycenter with numpy - bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased( - A, reg, method=method, verbose=True, log=True - ) - # Compute the barycenter with the backend - bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased( - Ab, reg, method=method - ) - # Convert the backend result to numpy, to compare with the numpy result - bary_wass = nx.to_numpy(bary_wass_b) + if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method) + else: + # Compute the barycenter with numpy + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased( + A, reg, method=method, verbose=True, log=True + ) + # Compute the barycenter with the backend + bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased( + Ab, reg, method=method + ) + # Convert the backend result to numpy, to compare with the numpy result + bary_wass = nx.to_numpy(bary_wass_b) - np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) - np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) - # help in checking if log and verbose do not bug the function - ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True) + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d_debiased( + A, reg, log=True, verbose=True + ) - # Test that the dtype and device are the same after the computation - nx.assert_same_dtype_device(Ab, bary_wass_b) + # Test that the dtype and device are the same after the computation + nx.assert_same_dtype_device(Ab, bary_wass_b) @pytest.mark.skipif(not tf, reason="tf not installed") @@ -1004,6 +1030,41 @@ def test_wasserstein_bary_2d_debiased_device_tf(method): # wasserstein reg = 1e-2 + if method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method) + else: + # Compute the barycenter with numpy + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased( + A, reg, method=method, verbose=True, log=True + ) + # Compute the barycenter with the backend + bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased( + Ab, reg, method=method + ) + # Convert the backend result to numpy, to compare with the numpy result + bary_wass = nx.to_numpy(bary_wass_b) + + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d_debiased( + A, reg, log=True, verbose=True + ) + + # Test that the dtype and device are the same after the computation + nx.assert_same_dtype_device(Ab, bary_wass_b) + + # Check that everything happens on the GPU + Ab = nx.from_numpy(A) + + # wasserstein + reg = 1e-2 + if method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method) + else: # Compute the barycenter with numpy bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased( A, reg, method=method, verbose=True, log=True @@ -1016,29 +1077,6 @@ def test_wasserstein_bary_2d_debiased_device_tf(method): bary_wass = nx.to_numpy(bary_wass_b) np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) - np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) - - # help in checking if log and verbose do not bug the function - ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True) - - # Test that the dtype and device are the same after the computation - nx.assert_same_dtype_device(Ab, bary_wass_b) - - # Check that everything happens on the GPU - Ab = nx.from_numpy(A) - - # wasserstein - reg = 1e-2 - # Compute the barycenter with numpy - bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased( - A, reg, method=method, verbose=True, log=True - ) - # Compute the barycenter with the backend - bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method) - # Convert the backend result to numpy, to compare with the numpy result - bary_wass = nx.to_numpy(bary_wass_b) - - np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) def test_unmix(nx): From 667ea6de4823c579f77636917010c897ac42fea0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= <fmunoz@ug.uchile.cl> Date: Wed, 8 Jan 2025 22:22:28 -0300 Subject: [PATCH 11/11] docs: add comments with the reason of the error --- ot/bregman/_convolutional.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ot/bregman/_convolutional.py b/ot/bregman/_convolutional.py index 412cca9e5..9a8253240 100644 --- a/ot/bregman/_convolutional.py +++ b/ot/bregman/_convolutional.py @@ -246,6 +246,8 @@ def _convolutional_barycenter2d_log( A = list_to_array(A) nx = get_backend(A) + # This error is raised because we are using mutable assignment in the line + # `log_KU[k] = ...` which is not allowed in Jax and TF. if nx.__name__ in ("jax", "tf"): raise NotImplementedError( "Log-domain functions are not yet implemented" @@ -483,6 +485,8 @@ def _convolutional_barycenter2d_debiased_log( A = list_to_array(A) n_hists, width, height = A.shape nx = get_backend(A) + # This error is raised because we are using mutable assignment in the line + # `log_KU[k] = ...` which is not allowed in Jax and TF. if nx.__name__ in ("jax", "tf"): raise NotImplementedError( "Log-domain functions are not yet implemented"