From 6d0a888b7804d7cbce0845a384ab37c458086fcc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= <fmunoz@ug.uchile.cl>
Date: Fri, 27 Dec 2024 11:12:48 -0300
Subject: [PATCH 01/11] feat: add _get_convol_img_fn

---
 ot/bregman/_convolutional.py | 97 +++++++++++++++---------------------
 1 file changed, 40 insertions(+), 57 deletions(-)

diff --git a/ot/bregman/_convolutional.py b/ot/bregman/_convolutional.py
index 0e6548710..9a9895d82 100644
--- a/ot/bregman/_convolutional.py
+++ b/ot/bregman/_convolutional.py
@@ -10,8 +10,36 @@
 
 import warnings
 
-from ..utils import list_to_array
 from ..backend import get_backend
+from ..utils import list_to_array
+
+
+def _get_convol_img_fn(nx, width, height, reg, type_as, log_domain=False):
+    """
+    Return the convolution operator for 2D images. The function constructed is equivalent to blurring on horizontal then vertical directions.
+    """
+    t1 = nx.linspace(0, 1, width, type_as=type_as)
+    Y1, X1 = nx.meshgrid(t1, t1)
+    M1 = -((X1 - Y1) ** 2) / reg
+
+    t2 = nx.linspace(0, 1, height, type_as=type_as)
+    Y2, X2 = nx.meshgrid(t2, t2)
+    M2 = -((X2 - Y2) ** 2) / reg
+
+    def convol_imgs(log_imgs):
+        log_imgs = nx.logsumexp(M1[:, :, None] + log_imgs[None], axis=1)
+        log_imgs = nx.logsumexp(M2[:, :, None] + log_imgs.T[None], axis=1).T
+        return log_imgs
+
+    if not log_domain:
+        K1, K2 = nx.exp(M1), nx.exp(M2)
+
+        def convol_imgs(imgs):
+            kx = nx.einsum("...ij,kjl->kil", K1, imgs)
+            kxy = nx.einsum("...ij,klj->kli", K2, kx)
+            return kxy
+
+    return convol_imgs
 
 
 def convolutional_barycenter2d(
@@ -133,37 +161,26 @@ def _convolutional_barycenter2d(
     """
 
     A = list_to_array(A)
+    n_hists, width, height = A.shape
 
     nx = get_backend(A)
 
     if weights is None:
-        weights = nx.ones((A.shape[0],), type_as=A) / A.shape[0]
+        weights = nx.ones((n_hists,), type_as=A) / n_hists
     else:
-        assert len(weights) == A.shape[0]
+        assert len(weights) == n_hists
 
     if log:
         log = {"err": []}
 
-    bar = nx.ones(A.shape[1:], type_as=A)
+    bar = nx.ones((width, height), type_as=A)
     bar /= nx.sum(bar)
     U = nx.ones(A.shape, type_as=A)
     V = nx.ones(A.shape, type_as=A)
     err = 1
 
     # build the convolution operator
-    # this is equivalent to blurring on horizontal then vertical directions
-    t = nx.linspace(0, 1, A.shape[1], type_as=A)
-    [Y, X] = nx.meshgrid(t, t)
-    K1 = nx.exp(-((X - Y) ** 2) / reg)
-
-    t = nx.linspace(0, 1, A.shape[2], type_as=A)
-    [Y, X] = nx.meshgrid(t, t)
-    K2 = nx.exp(-((X - Y) ** 2) / reg)
-
-    def convol_imgs(imgs):
-        kx = nx.einsum("...ij,kjl->kil", K1, imgs)
-        kxy = nx.einsum("...ij,klj->kli", K2, kx)
-        return kxy
+    convol_imgs = _get_convol_img_fn(nx, width, height, reg, type_as=A)
 
     KU = convol_imgs(U)
     for ii in range(numItermax):
@@ -195,6 +212,7 @@ def convol_imgs(imgs):
     if log:
         log["niter"] = ii
         log["U"] = U
+        log["V"] = V
         return bar, log
     else:
         return bar
@@ -236,19 +254,7 @@ def _convolutional_barycenter2d_log(
 
     err = 1
     # build the convolution operator
-    # this is equivalent to blurring on horizontal then vertical directions
-    t = nx.linspace(0, 1, width, type_as=A)
-    [Y, X] = nx.meshgrid(t, t)
-    M1 = -((X - Y) ** 2) / reg
-
-    t = nx.linspace(0, 1, height, type_as=A)
-    [Y, X] = nx.meshgrid(t, t)
-    M2 = -((X - Y) ** 2) / reg
-
-    def convol_img(log_img):
-        log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1)
-        log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T
-        return log_img
+    convol_img = _get_convol_img_fn(nx, width, height, reg, type_as=A, log_domain=True)
 
     logA = nx.log(A + stabThr)
     log_KU, G, F = nx.zeros((3, *logA.shape), type_as=A)
@@ -417,23 +423,11 @@ def _convolutional_barycenter2d_debiased(
     bar /= width * height
     U = nx.ones(A.shape, type_as=A)
     V = nx.ones(A.shape, type_as=A)
-    c = nx.ones(A.shape[1:], type_as=A)
+    c = nx.ones((width, height), type_as=A)
     err = 1
 
     # build the convolution operator
-    # this is equivalent to blurring on horizontal then vertical directions
-    t = nx.linspace(0, 1, width, type_as=A)
-    [Y, X] = nx.meshgrid(t, t)
-    K1 = nx.exp(-((X - Y) ** 2) / reg)
-
-    t = nx.linspace(0, 1, height, type_as=A)
-    [Y, X] = nx.meshgrid(t, t)
-    K2 = nx.exp(-((X - Y) ** 2) / reg)
-
-    def convol_imgs(imgs):
-        kx = nx.einsum("...ij,kjl->kil", K1, imgs)
-        kxy = nx.einsum("...ij,klj->kli", K2, kx)
-        return kxy
+    convol_imgs = _get_convol_img_fn(nx, width, height, reg, type_as=A)
 
     KU = convol_imgs(U)
     for ii in range(numItermax):
@@ -471,6 +465,7 @@ def convol_imgs(imgs):
     if log:
         log["niter"] = ii
         log["U"] = U
+        log["V"] = V
         return bar, log
     else:
         return bar
@@ -507,19 +502,7 @@ def _convolutional_barycenter2d_debiased_log(
 
     err = 1
     # build the convolution operator
-    # this is equivalent to blurring on horizontal then vertical directions
-    t = nx.linspace(0, 1, width, type_as=A)
-    [Y, X] = nx.meshgrid(t, t)
-    M1 = -((X - Y) ** 2) / reg
-
-    t = nx.linspace(0, 1, height, type_as=A)
-    [Y, X] = nx.meshgrid(t, t)
-    M2 = -((X - Y) ** 2) / reg
-
-    def convol_img(log_img):
-        log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1)
-        log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T
-        return log_img
+    convol_img = _get_convol_img_fn(nx, width, height, reg, type_as=A, log_domain=True)
 
     logA = nx.log(A + stabThr)
     log_bar, c = nx.zeros((2, width, height), type_as=A)

From 698321eb2973a3c5490dae3d59751af70acb97d4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= <fmunoz@ug.uchile.cl>
Date: Fri, 27 Dec 2024 11:17:22 -0300
Subject: [PATCH 02/11] refactor: add warning msg

---
 ot/bregman/_convolutional.py | 34 +++++++++++-----------------------
 1 file changed, 11 insertions(+), 23 deletions(-)

diff --git a/ot/bregman/_convolutional.py b/ot/bregman/_convolutional.py
index 9a9895d82..a3b999f22 100644
--- a/ot/bregman/_convolutional.py
+++ b/ot/bregman/_convolutional.py
@@ -13,11 +13,15 @@
 from ..backend import get_backend
 from ..utils import list_to_array
 
+_warning_msg = (
+    "Convolutional Sinkhorn did not converge. "
+    "Try a larger number of iterations `numItermax` "
+    "or a larger entropy `reg`."
+)
+
 
 def _get_convol_img_fn(nx, width, height, reg, type_as, log_domain=False):
-    """
-    Return the convolution operator for 2D images. The function constructed is equivalent to blurring on horizontal then vertical directions.
-    """
+    """Return the convolution operator for 2D images. The function constructed is equivalent to blurring on horizontal then vertical directions."""
     t1 = nx.linspace(0, 1, width, type_as=type_as)
     Y1, X1 = nx.meshgrid(t1, t1)
     M1 = -((X1 - Y1) ** 2) / reg
@@ -204,11 +208,7 @@ def _convolutional_barycenter2d(
 
     else:
         if warn:
-            warnings.warn(
-                "Convolutional Sinkhorn did not converge. "
-                "Try a larger number of iterations `numItermax` "
-                "or a larger entropy `reg`."
-            )
+            warnings.warn(_warning_msg)
     if log:
         log["niter"] = ii
         log["U"] = U
@@ -282,11 +282,7 @@ def _convolutional_barycenter2d_log(
 
     else:
         if warn:
-            warnings.warn(
-                "Convolutional Sinkhorn did not converge. "
-                "Try a larger number of iterations `numItermax` "
-                "or a larger entropy `reg`."
-            )
+            warnings.warn(_warning_msg)
     if log:
         log["niter"] = ii
         return nx.exp(log_bar), log
@@ -457,11 +453,7 @@ def _convolutional_barycenter2d_debiased(
                 break
     else:
         if warn:
-            warnings.warn(
-                "Sinkhorn did not converge. You might want to "
-                "increase the number of iterations `numItermax` "
-                "or the regularization parameter `reg`."
-            )
+            warnings.warn(_warning_msg)
     if log:
         log["niter"] = ii
         log["U"] = U
@@ -534,11 +526,7 @@ def _convolutional_barycenter2d_debiased_log(
 
     else:
         if warn:
-            warnings.warn(
-                "Convolutional Sinkhorn did not converge. "
-                "Try a larger number of iterations `numItermax` "
-                "or a larger entropy `reg`."
-            )
+            warnings.warn(_warning_msg)
     if log:
         log["niter"] = ii
         return nx.exp(log_bar), log

From 2dd4610e9ef00725804143645a1630f6a7cea4cc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= <fmunoz@ug.uchile.cl>
Date: Fri, 27 Dec 2024 11:21:30 -0300
Subject: [PATCH 03/11] refactor: encapsulate the report printing in a function

---
 ot/bregman/_convolutional.py | 27 +++++++++++----------------
 1 file changed, 11 insertions(+), 16 deletions(-)

diff --git a/ot/bregman/_convolutional.py b/ot/bregman/_convolutional.py
index a3b999f22..11a1c5903 100644
--- a/ot/bregman/_convolutional.py
+++ b/ot/bregman/_convolutional.py
@@ -46,6 +46,13 @@ def convol_imgs(imgs):
     return convol_imgs
 
 
+def _print_report(ii, err):
+    """Print the report of the iteration."""
+    if ii % 200 == 0:
+        print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
+    print("{:5d}|{:8e}|".format(ii, err))
+
+
 def convolutional_barycenter2d(
     A,
     reg,
@@ -198,11 +205,8 @@ def _convolutional_barycenter2d(
             # log and verbose print
             if log:
                 log["err"].append(err)
-
             if verbose:
-                if ii % 200 == 0:
-                    print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
-                print("{:5d}|{:8e}|".format(ii, err))
+                _print_report(ii, err)
             if err < stopThr:
                 break
 
@@ -271,11 +275,8 @@ def _convolutional_barycenter2d_log(
             # log and verbose print
             if log:
                 log["err"].append(err)
-
             if verbose:
-                if ii % 200 == 0:
-                    print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
-                print("{:5d}|{:8e}|".format(ii, err))
+                _print_report(ii, err)
             if err < stopThr:
                 break
         G = log_bar[None, :, :] - log_KU
@@ -441,11 +442,8 @@ def _convolutional_barycenter2d_debiased(
             # log and verbose print
             if log:
                 log["err"].append(err)
-
             if verbose:
-                if ii % 200 == 0:
-                    print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
-                print("{:5d}|{:8e}|".format(ii, err))
+                _print_report(ii, err)
 
             # debiased Sinkhorn does not converge monotonically
             # guarantee a few iterations are done before stopping
@@ -515,11 +513,8 @@ def _convolutional_barycenter2d_debiased_log(
             # log and verbose print
             if log:
                 log["err"].append(err)
-
             if verbose:
-                if ii % 200 == 0:
-                    print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
-                print("{:5d}|{:8e}|".format(ii, err))
+                _print_report(ii, err)
             if err < stopThr and ii > 20:
                 break
         G = log_bar[None, :, :] - log_KU

From 138bacb37091a142d40747bb26cc738d8ebaafe1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= <fmunoz@ug.uchile.cl>
Date: Fri, 27 Dec 2024 12:16:46 -0300
Subject: [PATCH 04/11] docs: add some documentation in the function
 _get_convol_img_fn

---
 ot/bregman/_convolutional.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/ot/bregman/_convolutional.py b/ot/bregman/_convolutional.py
index 11a1c5903..d4ad92466 100644
--- a/ot/bregman/_convolutional.py
+++ b/ot/bregman/_convolutional.py
@@ -21,7 +21,9 @@
 
 
 def _get_convol_img_fn(nx, width, height, reg, type_as, log_domain=False):
-    """Return the convolution operator for 2D images. The function constructed is equivalent to blurring on horizontal then vertical directions."""
+    """Return the convolution operator for 2D images.
+
+    The function constructed is equivalent to blurring on horizontal then vertical directions."""
     t1 = nx.linspace(0, 1, width, type_as=type_as)
     Y1, X1 = nx.meshgrid(t1, t1)
     M1 = -((X1 - Y1) ** 2) / reg
@@ -30,11 +32,13 @@ def _get_convol_img_fn(nx, width, height, reg, type_as, log_domain=False):
     Y2, X2 = nx.meshgrid(t2, t2)
     M2 = -((X2 - Y2) ** 2) / reg
 
+    # As M1 and M2 are computed first, we can use them to compute the convolution in log-domain
     def convol_imgs(log_imgs):
         log_imgs = nx.logsumexp(M1[:, :, None] + log_imgs[None], axis=1)
         log_imgs = nx.logsumexp(M2[:, :, None] + log_imgs.T[None], axis=1).T
         return log_imgs
 
+    # If normal domain is selected, we can use M1 and M2 to compute the convolution
     if not log_domain:
         K1, K2 = nx.exp(M1), nx.exp(M2)
 

From dadb470bf861beb9f5269bf178ce4daf4b0e2b7c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= <fmunoz@ug.uchile.cl>
Date: Fri, 27 Dec 2024 12:19:29 -0300
Subject: [PATCH 05/11] docs: add realise

---
 RELEASES.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/RELEASES.md b/RELEASES.md
index f3f100b66..0ddac599b 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -40,6 +40,7 @@ This release also contains few bug fixes, concerning the support of any metric i
 - Notes before depreciating partial Gromov-Wasserstein function in `ot.partial` moved to ot.gromov  (PR #663)
 - Create `ot.gromov._partial` add new features `loss_fun = "kl_loss"` and `symmetry=False` to all solvers while increasing speed + updating adequatly `ot.solvers` (PR #663)
 - Added `ot.unbalanced.sinkhorn_unbalanced_translation_invariant` (PR #676)
+- Refactored `ot.bregman._convolutional` to improve readability (PR #709)
 
 #### Closed issues
 - Fixed `ot.gaussian` ignoring weights when computing means (PR #649, Issue #648)

From 1030815e1b401336de2241211f71df386eb37a5e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= <femunoz@dim.uchile.cl>
Date: Thu, 2 Jan 2025 13:34:48 -0300
Subject: [PATCH 06/11] refactor: change function _get_convol_img_fn for more
 clarity

---
 ot/bregman/_convolutional.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/ot/bregman/_convolutional.py b/ot/bregman/_convolutional.py
index d4ad92466..c8793cb25 100644
--- a/ot/bregman/_convolutional.py
+++ b/ot/bregman/_convolutional.py
@@ -31,13 +31,7 @@ def _get_convol_img_fn(nx, width, height, reg, type_as, log_domain=False):
     t2 = nx.linspace(0, 1, height, type_as=type_as)
     Y2, X2 = nx.meshgrid(t2, t2)
     M2 = -((X2 - Y2) ** 2) / reg
-
-    # As M1 and M2 are computed first, we can use them to compute the convolution in log-domain
-    def convol_imgs(log_imgs):
-        log_imgs = nx.logsumexp(M1[:, :, None] + log_imgs[None], axis=1)
-        log_imgs = nx.logsumexp(M2[:, :, None] + log_imgs.T[None], axis=1).T
-        return log_imgs
-
+    
     # If normal domain is selected, we can use M1 and M2 to compute the convolution
     if not log_domain:
         K1, K2 = nx.exp(M1), nx.exp(M2)
@@ -47,6 +41,13 @@ def convol_imgs(imgs):
             kxy = nx.einsum("...ij,klj->kli", K2, kx)
             return kxy
 
+    # Else, we can use M1 and M2 to compute the convolution in log-domain
+    else:
+        def convol_imgs(log_imgs):
+            log_imgs = nx.logsumexp(M1[:, :, None] + log_imgs[None], axis=1)
+            log_imgs = nx.logsumexp(M2[:, :, None] + log_imgs.T[None], axis=1).T
+            return log_imgs
+
     return convol_imgs
 
 

From a70cf8dc5f2d25184b4d0d6b8198b3f9b2722020 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= <fmunoz@ug.uchile.cl>
Date: Fri, 3 Jan 2025 09:02:14 -0300
Subject: [PATCH 07/11] refactor: run pre-commit

---
 ot/bregman/_convolutional.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/ot/bregman/_convolutional.py b/ot/bregman/_convolutional.py
index c8793cb25..412cca9e5 100644
--- a/ot/bregman/_convolutional.py
+++ b/ot/bregman/_convolutional.py
@@ -31,7 +31,7 @@ def _get_convol_img_fn(nx, width, height, reg, type_as, log_domain=False):
     t2 = nx.linspace(0, 1, height, type_as=type_as)
     Y2, X2 = nx.meshgrid(t2, t2)
     M2 = -((X2 - Y2) ** 2) / reg
-    
+
     # If normal domain is selected, we can use M1 and M2 to compute the convolution
     if not log_domain:
         K1, K2 = nx.exp(M1), nx.exp(M2)
@@ -43,6 +43,7 @@ def convol_imgs(imgs):
 
     # Else, we can use M1 and M2 to compute the convolution in log-domain
     else:
+
         def convol_imgs(log_imgs):
             log_imgs = nx.logsumexp(M1[:, :, None] + log_imgs[None], axis=1)
             log_imgs = nx.logsumexp(M2[:, :, None] + log_imgs.T[None], axis=1).T

From 80d954255c76d2bd240b0e403b64fd2c1a384fc7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= <fmunoz@ug.uchile.cl>
Date: Wed, 8 Jan 2025 20:10:38 -0300
Subject: [PATCH 08/11] test: refactor tests to delete the error for
 unavailable backends

---
 test/test_bregman.py | 236 ++++++++++++++++++-------------------------
 1 file changed, 99 insertions(+), 137 deletions(-)

diff --git a/test/test_bregman.py b/test/test_bregman.py
index 6c0c0e8f2..d3dbb5b11 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -825,22 +825,18 @@ def test_wasserstein_bary_2d(nx, method):
 
     # wasserstein
     reg = 1e-2
-    if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
-        with pytest.raises(NotImplementedError):
-            ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
-    else:
-        bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
-            A, reg, method=method, verbose=True, log=True
-        )
-        bary_wass = nx.to_numpy(
-            ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
-        )
+    bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
+        A, reg, method=method, verbose=True, log=True
+    )
+    bary_wass = nx.to_numpy(
+        ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
+    )
 
-        np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
-        np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
+    np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
+    np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
 
-        # help in checking if log and verbose do not bug the function
-        ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
+    # help in checking if log and verbose do not bug the function
+    ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
 
 
 @pytest.skip_backend("tf")
@@ -856,27 +852,23 @@ def test_wasserstein_bary_2d_dtype_device(nx, method):
 
         # wasserstein
         reg = 1e-2
-        if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
-            with pytest.raises(NotImplementedError):
-                ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
-        else:
-            # Compute the barycenter with numpy
-            bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
-                A, reg, method=method, verbose=True, log=True
-            )
-            # Compute the barycenter with the backend
-            bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
-            # Convert the backend result to numpy, to compare with the numpy result
-            bary_wass = nx.to_numpy(bary_wass_b)
+        # Compute the barycenter with numpy
+        bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
+            A, reg, method=method, verbose=True, log=True
+        )
+        # Compute the barycenter with the backend
+        bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
+        # Convert the backend result to numpy, to compare with the numpy result
+        bary_wass = nx.to_numpy(bary_wass_b)
 
-            np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
-            np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
+        np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
+        np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
 
-            # help in checking if log and verbose do not bug the function
-            ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
+        # help in checking if log and verbose do not bug the function
+        ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
 
-            # Test that the dtype and device are the same after the computation
-            nx.assert_same_dtype_device(Ab, bary_wass_b)
+        # Test that the dtype and device are the same after the computation
+        nx.assert_same_dtype_device(Ab, bary_wass_b)
 
 
 @pytest.mark.skipif(not tf, reason="tf not installed")
@@ -894,37 +886,6 @@ def test_wasserstein_bary_2d_device_tf(method):
 
         # wasserstein
         reg = 1e-2
-        if method == "sinkhorn_log":
-            with pytest.raises(NotImplementedError):
-                ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
-        else:
-            # Compute the barycenter with numpy
-            bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
-                A, reg, method=method, verbose=True, log=True
-            )
-            # Compute the barycenter with the backend
-            bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
-            # Convert the backend result to numpy, to compare with the numpy result
-            bary_wass = nx.to_numpy(bary_wass_b)
-
-            np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
-            np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
-
-            # help in checking if log and verbose do not bug the function
-            ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
-
-            # Test that the dtype and device are the same after the computation
-            nx.assert_same_dtype_device(Ab, bary_wass_b)
-
-    # Check that everything happens on the GPU
-    Ab = nx.from_numpy(A)
-
-    # wasserstein
-    reg = 1e-2
-    if method == "sinkhorn_log":
-        with pytest.raises(NotImplementedError):
-            ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
-    else:
         # Compute the barycenter with numpy
         bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
             A, reg, method=method, verbose=True, log=True
@@ -943,9 +904,32 @@ def test_wasserstein_bary_2d_device_tf(method):
         # Test that the dtype and device are the same after the computation
         nx.assert_same_dtype_device(Ab, bary_wass_b)
 
-        # Check this only if GPU is available
-        if len(tf.config.list_physical_devices("GPU")) > 0:
-            assert nx.dtype_device(bary_wass_b)[1].startswith("GPU")
+    # Check that everything happens on the GPU
+    Ab = nx.from_numpy(A)
+
+    # wasserstein
+    reg = 1e-2
+    # Compute the barycenter with numpy
+    bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
+        A, reg, method=method, verbose=True, log=True
+    )
+    # Compute the barycenter with the backend
+    bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
+    # Convert the backend result to numpy, to compare with the numpy result
+    bary_wass = nx.to_numpy(bary_wass_b)
+
+    np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
+    np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
+
+    # help in checking if log and verbose do not bug the function
+    ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
+
+    # Test that the dtype and device are the same after the computation
+    nx.assert_same_dtype_device(Ab, bary_wass_b)
+
+    # Check this only if GPU is available
+    if len(tf.config.list_physical_devices("GPU")) > 0:
+        assert nx.dtype_device(bary_wass_b)[1].startswith("GPU")
 
 
 @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
@@ -957,22 +941,18 @@ def test_wasserstein_bary_2d_debiased(nx, method):
 
     # wasserstein
     reg = 1e-2
-    if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
-        with pytest.raises(NotImplementedError):
-            ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
-    else:
-        bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
-            A, reg, method=method, verbose=True, log=True
-        )
-        bary_wass = nx.to_numpy(
-            ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
-        )
+    bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
+        A, reg, method=method, verbose=True, log=True
+    )
+    bary_wass = nx.to_numpy(
+        ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
+    )
 
-        np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
-        np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
+    np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
+    np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
 
-        # help in checking if log and verbose do not bug the function
-        ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True)
+    # help in checking if log and verbose do not bug the function
+    ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True)
 
 
 @pytest.skip_backend("tf")
@@ -988,31 +968,25 @@ def test_wasserstein_bary_2d_debiased_dtype_device(nx, method):
 
         # wasserstein
         reg = 1e-2
-        if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
-            with pytest.raises(NotImplementedError):
-                ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method)
-        else:
-            # Compute the barycenter with numpy
-            bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
-                A, reg, method=method, verbose=True, log=True
-            )
-            # Compute the barycenter with the backend
-            bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased(
-                Ab, reg, method=method
-            )
-            # Convert the backend result to numpy, to compare with the numpy result
-            bary_wass = nx.to_numpy(bary_wass_b)
+        # Compute the barycenter with numpy
+        bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
+            A, reg, method=method, verbose=True, log=True
+        )
+        # Compute the barycenter with the backend
+        bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased(
+            Ab, reg, method=method
+        )
+        # Convert the backend result to numpy, to compare with the numpy result
+        bary_wass = nx.to_numpy(bary_wass_b)
 
-            np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
-            np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
+        np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
+        np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
 
-            # help in checking if log and verbose do not bug the function
-            ot.bregman.convolutional_barycenter2d_debiased(
-                A, reg, log=True, verbose=True
-            )
+        # help in checking if log and verbose do not bug the function
+        ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True)
 
-            # Test that the dtype and device are the same after the computation
-            nx.assert_same_dtype_device(Ab, bary_wass_b)
+        # Test that the dtype and device are the same after the computation
+        nx.assert_same_dtype_device(Ab, bary_wass_b)
 
 
 @pytest.mark.skipif(not tf, reason="tf not installed")
@@ -1030,41 +1004,6 @@ def test_wasserstein_bary_2d_debiased_device_tf(method):
 
         # wasserstein
         reg = 1e-2
-        if method == "sinkhorn_log":
-            with pytest.raises(NotImplementedError):
-                ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method)
-        else:
-            # Compute the barycenter with numpy
-            bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
-                A, reg, method=method, verbose=True, log=True
-            )
-            # Compute the barycenter with the backend
-            bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased(
-                Ab, reg, method=method
-            )
-            # Convert the backend result to numpy, to compare with the numpy result
-            bary_wass = nx.to_numpy(bary_wass_b)
-
-            np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
-            np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
-
-            # help in checking if log and verbose do not bug the function
-            ot.bregman.convolutional_barycenter2d_debiased(
-                A, reg, log=True, verbose=True
-            )
-
-            # Test that the dtype and device are the same after the computation
-            nx.assert_same_dtype_device(Ab, bary_wass_b)
-
-    # Check that everything happens on the GPU
-    Ab = nx.from_numpy(A)
-
-    # wasserstein
-    reg = 1e-2
-    if method == "sinkhorn_log":
-        with pytest.raises(NotImplementedError):
-            ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method)
-    else:
         # Compute the barycenter with numpy
         bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
             A, reg, method=method, verbose=True, log=True
@@ -1077,6 +1016,29 @@ def test_wasserstein_bary_2d_debiased_device_tf(method):
         bary_wass = nx.to_numpy(bary_wass_b)
 
         np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
+        np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
+
+        # help in checking if log and verbose do not bug the function
+        ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True)
+
+        # Test that the dtype and device are the same after the computation
+        nx.assert_same_dtype_device(Ab, bary_wass_b)
+
+    # Check that everything happens on the GPU
+    Ab = nx.from_numpy(A)
+
+    # wasserstein
+    reg = 1e-2
+    # Compute the barycenter with numpy
+    bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
+        A, reg, method=method, verbose=True, log=True
+    )
+    # Compute the barycenter with the backend
+    bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method)
+    # Convert the backend result to numpy, to compare with the numpy result
+    bary_wass = nx.to_numpy(bary_wass_b)
+
+    np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
 
 
 def test_unmix(nx):

From e8d4de5c105df7dcc79c3ac9a9e3784b46ce3ac9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= <fmunoz@ug.uchile.cl>
Date: Wed, 8 Jan 2025 20:12:42 -0300
Subject: [PATCH 09/11] feat: delete not implemented error in convolutional
 module

---
 ot/bregman/_convolutional.py | 11 +----------
 1 file changed, 1 insertion(+), 10 deletions(-)

diff --git a/ot/bregman/_convolutional.py b/ot/bregman/_convolutional.py
index 412cca9e5..6f0de8330 100644
--- a/ot/bregman/_convolutional.py
+++ b/ot/bregman/_convolutional.py
@@ -246,11 +246,6 @@ def _convolutional_barycenter2d_log(
     A = list_to_array(A)
 
     nx = get_backend(A)
-    if nx.__name__ in ("jax", "tf"):
-        raise NotImplementedError(
-            "Log-domain functions are not yet implemented"
-            " for Jax and TF. Use numpy or torch arrays instead."
-        )
 
     n_hists, width, height = A.shape
 
@@ -483,11 +478,7 @@ def _convolutional_barycenter2d_debiased_log(
     A = list_to_array(A)
     n_hists, width, height = A.shape
     nx = get_backend(A)
-    if nx.__name__ in ("jax", "tf"):
-        raise NotImplementedError(
-            "Log-domain functions are not yet implemented"
-            " for Jax and TF. Use numpy or torch arrays instead."
-        )
+
     if weights is None:
         weights = nx.ones((n_hists,), type_as=A) / n_hists
     else:

From b0c5ef335a4aa8e9063f34cb39acc52a028f27d0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= <fmunoz@ug.uchile.cl>
Date: Wed, 8 Jan 2025 22:21:21 -0300
Subject: [PATCH 10/11] revert the last two commits

---
 ot/bregman/_convolutional.py |  11 +-
 test/test_bregman.py         | 236 ++++++++++++++++++++---------------
 2 files changed, 147 insertions(+), 100 deletions(-)

diff --git a/ot/bregman/_convolutional.py b/ot/bregman/_convolutional.py
index 6f0de8330..412cca9e5 100644
--- a/ot/bregman/_convolutional.py
+++ b/ot/bregman/_convolutional.py
@@ -246,6 +246,11 @@ def _convolutional_barycenter2d_log(
     A = list_to_array(A)
 
     nx = get_backend(A)
+    if nx.__name__ in ("jax", "tf"):
+        raise NotImplementedError(
+            "Log-domain functions are not yet implemented"
+            " for Jax and TF. Use numpy or torch arrays instead."
+        )
 
     n_hists, width, height = A.shape
 
@@ -478,7 +483,11 @@ def _convolutional_barycenter2d_debiased_log(
     A = list_to_array(A)
     n_hists, width, height = A.shape
     nx = get_backend(A)
-
+    if nx.__name__ in ("jax", "tf"):
+        raise NotImplementedError(
+            "Log-domain functions are not yet implemented"
+            " for Jax and TF. Use numpy or torch arrays instead."
+        )
     if weights is None:
         weights = nx.ones((n_hists,), type_as=A) / n_hists
     else:
diff --git a/test/test_bregman.py b/test/test_bregman.py
index d3dbb5b11..6c0c0e8f2 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -825,18 +825,22 @@ def test_wasserstein_bary_2d(nx, method):
 
     # wasserstein
     reg = 1e-2
-    bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
-        A, reg, method=method, verbose=True, log=True
-    )
-    bary_wass = nx.to_numpy(
-        ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
-    )
+    if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
+        with pytest.raises(NotImplementedError):
+            ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
+    else:
+        bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
+            A, reg, method=method, verbose=True, log=True
+        )
+        bary_wass = nx.to_numpy(
+            ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
+        )
 
-    np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
-    np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
+        np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
+        np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
 
-    # help in checking if log and verbose do not bug the function
-    ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
+        # help in checking if log and verbose do not bug the function
+        ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
 
 
 @pytest.skip_backend("tf")
@@ -852,23 +856,27 @@ def test_wasserstein_bary_2d_dtype_device(nx, method):
 
         # wasserstein
         reg = 1e-2
-        # Compute the barycenter with numpy
-        bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
-            A, reg, method=method, verbose=True, log=True
-        )
-        # Compute the barycenter with the backend
-        bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
-        # Convert the backend result to numpy, to compare with the numpy result
-        bary_wass = nx.to_numpy(bary_wass_b)
+        if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
+            with pytest.raises(NotImplementedError):
+                ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
+        else:
+            # Compute the barycenter with numpy
+            bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
+                A, reg, method=method, verbose=True, log=True
+            )
+            # Compute the barycenter with the backend
+            bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
+            # Convert the backend result to numpy, to compare with the numpy result
+            bary_wass = nx.to_numpy(bary_wass_b)
 
-        np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
-        np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
+            np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
+            np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
 
-        # help in checking if log and verbose do not bug the function
-        ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
+            # help in checking if log and verbose do not bug the function
+            ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
 
-        # Test that the dtype and device are the same after the computation
-        nx.assert_same_dtype_device(Ab, bary_wass_b)
+            # Test that the dtype and device are the same after the computation
+            nx.assert_same_dtype_device(Ab, bary_wass_b)
 
 
 @pytest.mark.skipif(not tf, reason="tf not installed")
@@ -886,6 +894,37 @@ def test_wasserstein_bary_2d_device_tf(method):
 
         # wasserstein
         reg = 1e-2
+        if method == "sinkhorn_log":
+            with pytest.raises(NotImplementedError):
+                ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
+        else:
+            # Compute the barycenter with numpy
+            bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
+                A, reg, method=method, verbose=True, log=True
+            )
+            # Compute the barycenter with the backend
+            bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
+            # Convert the backend result to numpy, to compare with the numpy result
+            bary_wass = nx.to_numpy(bary_wass_b)
+
+            np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
+            np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
+
+            # help in checking if log and verbose do not bug the function
+            ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
+
+            # Test that the dtype and device are the same after the computation
+            nx.assert_same_dtype_device(Ab, bary_wass_b)
+
+    # Check that everything happens on the GPU
+    Ab = nx.from_numpy(A)
+
+    # wasserstein
+    reg = 1e-2
+    if method == "sinkhorn_log":
+        with pytest.raises(NotImplementedError):
+            ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
+    else:
         # Compute the barycenter with numpy
         bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
             A, reg, method=method, verbose=True, log=True
@@ -904,32 +943,9 @@ def test_wasserstein_bary_2d_device_tf(method):
         # Test that the dtype and device are the same after the computation
         nx.assert_same_dtype_device(Ab, bary_wass_b)
 
-    # Check that everything happens on the GPU
-    Ab = nx.from_numpy(A)
-
-    # wasserstein
-    reg = 1e-2
-    # Compute the barycenter with numpy
-    bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
-        A, reg, method=method, verbose=True, log=True
-    )
-    # Compute the barycenter with the backend
-    bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
-    # Convert the backend result to numpy, to compare with the numpy result
-    bary_wass = nx.to_numpy(bary_wass_b)
-
-    np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
-    np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
-
-    # help in checking if log and verbose do not bug the function
-    ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
-
-    # Test that the dtype and device are the same after the computation
-    nx.assert_same_dtype_device(Ab, bary_wass_b)
-
-    # Check this only if GPU is available
-    if len(tf.config.list_physical_devices("GPU")) > 0:
-        assert nx.dtype_device(bary_wass_b)[1].startswith("GPU")
+        # Check this only if GPU is available
+        if len(tf.config.list_physical_devices("GPU")) > 0:
+            assert nx.dtype_device(bary_wass_b)[1].startswith("GPU")
 
 
 @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
@@ -941,18 +957,22 @@ def test_wasserstein_bary_2d_debiased(nx, method):
 
     # wasserstein
     reg = 1e-2
-    bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
-        A, reg, method=method, verbose=True, log=True
-    )
-    bary_wass = nx.to_numpy(
-        ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
-    )
+    if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
+        with pytest.raises(NotImplementedError):
+            ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
+    else:
+        bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
+            A, reg, method=method, verbose=True, log=True
+        )
+        bary_wass = nx.to_numpy(
+            ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
+        )
 
-    np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
-    np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
+        np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
+        np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
 
-    # help in checking if log and verbose do not bug the function
-    ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True)
+        # help in checking if log and verbose do not bug the function
+        ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True)
 
 
 @pytest.skip_backend("tf")
@@ -968,25 +988,31 @@ def test_wasserstein_bary_2d_debiased_dtype_device(nx, method):
 
         # wasserstein
         reg = 1e-2
-        # Compute the barycenter with numpy
-        bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
-            A, reg, method=method, verbose=True, log=True
-        )
-        # Compute the barycenter with the backend
-        bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased(
-            Ab, reg, method=method
-        )
-        # Convert the backend result to numpy, to compare with the numpy result
-        bary_wass = nx.to_numpy(bary_wass_b)
+        if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
+            with pytest.raises(NotImplementedError):
+                ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method)
+        else:
+            # Compute the barycenter with numpy
+            bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
+                A, reg, method=method, verbose=True, log=True
+            )
+            # Compute the barycenter with the backend
+            bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased(
+                Ab, reg, method=method
+            )
+            # Convert the backend result to numpy, to compare with the numpy result
+            bary_wass = nx.to_numpy(bary_wass_b)
 
-        np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
-        np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
+            np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
+            np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
 
-        # help in checking if log and verbose do not bug the function
-        ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True)
+            # help in checking if log and verbose do not bug the function
+            ot.bregman.convolutional_barycenter2d_debiased(
+                A, reg, log=True, verbose=True
+            )
 
-        # Test that the dtype and device are the same after the computation
-        nx.assert_same_dtype_device(Ab, bary_wass_b)
+            # Test that the dtype and device are the same after the computation
+            nx.assert_same_dtype_device(Ab, bary_wass_b)
 
 
 @pytest.mark.skipif(not tf, reason="tf not installed")
@@ -1004,6 +1030,41 @@ def test_wasserstein_bary_2d_debiased_device_tf(method):
 
         # wasserstein
         reg = 1e-2
+        if method == "sinkhorn_log":
+            with pytest.raises(NotImplementedError):
+                ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method)
+        else:
+            # Compute the barycenter with numpy
+            bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
+                A, reg, method=method, verbose=True, log=True
+            )
+            # Compute the barycenter with the backend
+            bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased(
+                Ab, reg, method=method
+            )
+            # Convert the backend result to numpy, to compare with the numpy result
+            bary_wass = nx.to_numpy(bary_wass_b)
+
+            np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
+            np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
+
+            # help in checking if log and verbose do not bug the function
+            ot.bregman.convolutional_barycenter2d_debiased(
+                A, reg, log=True, verbose=True
+            )
+
+            # Test that the dtype and device are the same after the computation
+            nx.assert_same_dtype_device(Ab, bary_wass_b)
+
+    # Check that everything happens on the GPU
+    Ab = nx.from_numpy(A)
+
+    # wasserstein
+    reg = 1e-2
+    if method == "sinkhorn_log":
+        with pytest.raises(NotImplementedError):
+            ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method)
+    else:
         # Compute the barycenter with numpy
         bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
             A, reg, method=method, verbose=True, log=True
@@ -1016,29 +1077,6 @@ def test_wasserstein_bary_2d_debiased_device_tf(method):
         bary_wass = nx.to_numpy(bary_wass_b)
 
         np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
-        np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
-
-        # help in checking if log and verbose do not bug the function
-        ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True)
-
-        # Test that the dtype and device are the same after the computation
-        nx.assert_same_dtype_device(Ab, bary_wass_b)
-
-    # Check that everything happens on the GPU
-    Ab = nx.from_numpy(A)
-
-    # wasserstein
-    reg = 1e-2
-    # Compute the barycenter with numpy
-    bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
-        A, reg, method=method, verbose=True, log=True
-    )
-    # Compute the barycenter with the backend
-    bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method)
-    # Convert the backend result to numpy, to compare with the numpy result
-    bary_wass = nx.to_numpy(bary_wass_b)
-
-    np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
 
 
 def test_unmix(nx):

From 667ea6de4823c579f77636917010c897ac42fea0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= <fmunoz@ug.uchile.cl>
Date: Wed, 8 Jan 2025 22:22:28 -0300
Subject: [PATCH 11/11] docs: add comments with the reason of the error

---
 ot/bregman/_convolutional.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/ot/bregman/_convolutional.py b/ot/bregman/_convolutional.py
index 412cca9e5..9a8253240 100644
--- a/ot/bregman/_convolutional.py
+++ b/ot/bregman/_convolutional.py
@@ -246,6 +246,8 @@ def _convolutional_barycenter2d_log(
     A = list_to_array(A)
 
     nx = get_backend(A)
+    # This error is raised because we are using mutable assignment in the line
+    #  `log_KU[k] = ...` which is not allowed in Jax and TF.
     if nx.__name__ in ("jax", "tf"):
         raise NotImplementedError(
             "Log-domain functions are not yet implemented"
@@ -483,6 +485,8 @@ def _convolutional_barycenter2d_debiased_log(
     A = list_to_array(A)
     n_hists, width, height = A.shape
     nx = get_backend(A)
+    # This error is raised because we are using mutable assignment in the line
+    #  `log_KU[k] = ...` which is not allowed in Jax and TF.
     if nx.__name__ in ("jax", "tf"):
         raise NotImplementedError(
             "Log-domain functions are not yet implemented"