diff --git a/RELEASES.md b/RELEASES.md index f3f100b66..0ddac599b 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -40,6 +40,7 @@ This release also contains few bug fixes, concerning the support of any metric i - Notes before depreciating partial Gromov-Wasserstein function in `ot.partial` moved to ot.gromov (PR #663) - Create `ot.gromov._partial` add new features `loss_fun = "kl_loss"` and `symmetry=False` to all solvers while increasing speed + updating adequatly `ot.solvers` (PR #663) - Added `ot.unbalanced.sinkhorn_unbalanced_translation_invariant` (PR #676) +- Refactored `ot.bregman._convolutional` to improve readability (PR #709) #### Closed issues - Fixed `ot.gaussian` ignoring weights when computing means (PR #649, Issue #648) diff --git a/ot/bregman/_convolutional.py b/ot/bregman/_convolutional.py index 0e6548710..9a8253240 100644 --- a/ot/bregman/_convolutional.py +++ b/ot/bregman/_convolutional.py @@ -10,8 +10,53 @@ import warnings -from ..utils import list_to_array from ..backend import get_backend +from ..utils import list_to_array + +_warning_msg = ( + "Convolutional Sinkhorn did not converge. " + "Try a larger number of iterations `numItermax` " + "or a larger entropy `reg`." +) + + +def _get_convol_img_fn(nx, width, height, reg, type_as, log_domain=False): + """Return the convolution operator for 2D images. + + The function constructed is equivalent to blurring on horizontal then vertical directions.""" + t1 = nx.linspace(0, 1, width, type_as=type_as) + Y1, X1 = nx.meshgrid(t1, t1) + M1 = -((X1 - Y1) ** 2) / reg + + t2 = nx.linspace(0, 1, height, type_as=type_as) + Y2, X2 = nx.meshgrid(t2, t2) + M2 = -((X2 - Y2) ** 2) / reg + + # If normal domain is selected, we can use M1 and M2 to compute the convolution + if not log_domain: + K1, K2 = nx.exp(M1), nx.exp(M2) + + def convol_imgs(imgs): + kx = nx.einsum("...ij,kjl->kil", K1, imgs) + kxy = nx.einsum("...ij,klj->kli", K2, kx) + return kxy + + # Else, we can use M1 and M2 to compute the convolution in log-domain + else: + + def convol_imgs(log_imgs): + log_imgs = nx.logsumexp(M1[:, :, None] + log_imgs[None], axis=1) + log_imgs = nx.logsumexp(M2[:, :, None] + log_imgs.T[None], axis=1).T + return log_imgs + + return convol_imgs + + +def _print_report(ii, err): + """Print the report of the iteration.""" + if ii % 200 == 0: + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(ii, err)) def convolutional_barycenter2d( @@ -133,37 +178,26 @@ def _convolutional_barycenter2d( """ A = list_to_array(A) + n_hists, width, height = A.shape nx = get_backend(A) if weights is None: - weights = nx.ones((A.shape[0],), type_as=A) / A.shape[0] + weights = nx.ones((n_hists,), type_as=A) / n_hists else: - assert len(weights) == A.shape[0] + assert len(weights) == n_hists if log: log = {"err": []} - bar = nx.ones(A.shape[1:], type_as=A) + bar = nx.ones((width, height), type_as=A) bar /= nx.sum(bar) U = nx.ones(A.shape, type_as=A) V = nx.ones(A.shape, type_as=A) err = 1 # build the convolution operator - # this is equivalent to blurring on horizontal then vertical directions - t = nx.linspace(0, 1, A.shape[1], type_as=A) - [Y, X] = nx.meshgrid(t, t) - K1 = nx.exp(-((X - Y) ** 2) / reg) - - t = nx.linspace(0, 1, A.shape[2], type_as=A) - [Y, X] = nx.meshgrid(t, t) - K2 = nx.exp(-((X - Y) ** 2) / reg) - - def convol_imgs(imgs): - kx = nx.einsum("...ij,kjl->kil", K1, imgs) - kxy = nx.einsum("...ij,klj->kli", K2, kx) - return kxy + convol_imgs = _get_convol_img_fn(nx, width, height, reg, type_as=A) KU = convol_imgs(U) for ii in range(numItermax): @@ -177,24 +211,18 @@ def convol_imgs(imgs): # log and verbose print if log: log["err"].append(err) - if verbose: - if ii % 200 == 0: - print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) - print("{:5d}|{:8e}|".format(ii, err)) + _print_report(ii, err) if err < stopThr: break else: if warn: - warnings.warn( - "Convolutional Sinkhorn did not converge. " - "Try a larger number of iterations `numItermax` " - "or a larger entropy `reg`." - ) + warnings.warn(_warning_msg) if log: log["niter"] = ii log["U"] = U + log["V"] = V return bar, log else: return bar @@ -218,6 +246,8 @@ def _convolutional_barycenter2d_log( A = list_to_array(A) nx = get_backend(A) + # This error is raised because we are using mutable assignment in the line + # `log_KU[k] = ...` which is not allowed in Jax and TF. if nx.__name__ in ("jax", "tf"): raise NotImplementedError( "Log-domain functions are not yet implemented" @@ -236,19 +266,7 @@ def _convolutional_barycenter2d_log( err = 1 # build the convolution operator - # this is equivalent to blurring on horizontal then vertical directions - t = nx.linspace(0, 1, width, type_as=A) - [Y, X] = nx.meshgrid(t, t) - M1 = -((X - Y) ** 2) / reg - - t = nx.linspace(0, 1, height, type_as=A) - [Y, X] = nx.meshgrid(t, t) - M2 = -((X - Y) ** 2) / reg - - def convol_img(log_img): - log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1) - log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T - return log_img + convol_img = _get_convol_img_fn(nx, width, height, reg, type_as=A, log_domain=True) logA = nx.log(A + stabThr) log_KU, G, F = nx.zeros((3, *logA.shape), type_as=A) @@ -265,22 +283,15 @@ def convol_img(log_img): # log and verbose print if log: log["err"].append(err) - if verbose: - if ii % 200 == 0: - print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) - print("{:5d}|{:8e}|".format(ii, err)) + _print_report(ii, err) if err < stopThr: break G = log_bar[None, :, :] - log_KU else: if warn: - warnings.warn( - "Convolutional Sinkhorn did not converge. " - "Try a larger number of iterations `numItermax` " - "or a larger entropy `reg`." - ) + warnings.warn(_warning_msg) if log: log["niter"] = ii return nx.exp(log_bar), log @@ -417,23 +428,11 @@ def _convolutional_barycenter2d_debiased( bar /= width * height U = nx.ones(A.shape, type_as=A) V = nx.ones(A.shape, type_as=A) - c = nx.ones(A.shape[1:], type_as=A) + c = nx.ones((width, height), type_as=A) err = 1 # build the convolution operator - # this is equivalent to blurring on horizontal then vertical directions - t = nx.linspace(0, 1, width, type_as=A) - [Y, X] = nx.meshgrid(t, t) - K1 = nx.exp(-((X - Y) ** 2) / reg) - - t = nx.linspace(0, 1, height, type_as=A) - [Y, X] = nx.meshgrid(t, t) - K2 = nx.exp(-((X - Y) ** 2) / reg) - - def convol_imgs(imgs): - kx = nx.einsum("...ij,kjl->kil", K1, imgs) - kxy = nx.einsum("...ij,klj->kli", K2, kx) - return kxy + convol_imgs = _get_convol_img_fn(nx, width, height, reg, type_as=A) KU = convol_imgs(U) for ii in range(numItermax): @@ -451,11 +450,8 @@ def convol_imgs(imgs): # log and verbose print if log: log["err"].append(err) - if verbose: - if ii % 200 == 0: - print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) - print("{:5d}|{:8e}|".format(ii, err)) + _print_report(ii, err) # debiased Sinkhorn does not converge monotonically # guarantee a few iterations are done before stopping @@ -463,14 +459,11 @@ def convol_imgs(imgs): break else: if warn: - warnings.warn( - "Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`." - ) + warnings.warn(_warning_msg) if log: log["niter"] = ii log["U"] = U + log["V"] = V return bar, log else: return bar @@ -492,6 +485,8 @@ def _convolutional_barycenter2d_debiased_log( A = list_to_array(A) n_hists, width, height = A.shape nx = get_backend(A) + # This error is raised because we are using mutable assignment in the line + # `log_KU[k] = ...` which is not allowed in Jax and TF. if nx.__name__ in ("jax", "tf"): raise NotImplementedError( "Log-domain functions are not yet implemented" @@ -507,19 +502,7 @@ def _convolutional_barycenter2d_debiased_log( err = 1 # build the convolution operator - # this is equivalent to blurring on horizontal then vertical directions - t = nx.linspace(0, 1, width, type_as=A) - [Y, X] = nx.meshgrid(t, t) - M1 = -((X - Y) ** 2) / reg - - t = nx.linspace(0, 1, height, type_as=A) - [Y, X] = nx.meshgrid(t, t) - M2 = -((X - Y) ** 2) / reg - - def convol_img(log_img): - log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1) - log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T - return log_img + convol_img = _get_convol_img_fn(nx, width, height, reg, type_as=A, log_domain=True) logA = nx.log(A + stabThr) log_bar, c = nx.zeros((2, width, height), type_as=A) @@ -540,22 +523,15 @@ def convol_img(log_img): # log and verbose print if log: log["err"].append(err) - if verbose: - if ii % 200 == 0: - print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) - print("{:5d}|{:8e}|".format(ii, err)) + _print_report(ii, err) if err < stopThr and ii > 20: break G = log_bar[None, :, :] - log_KU else: if warn: - warnings.warn( - "Convolutional Sinkhorn did not converge. " - "Try a larger number of iterations `numItermax` " - "or a larger entropy `reg`." - ) + warnings.warn(_warning_msg) if log: log["niter"] = ii return nx.exp(log_bar), log