Skip to content

Commit

Permalink
[MRG] Refactor: improve readibility in the convolutional module (#709)
Browse files Browse the repository at this point in the history
* feat: add _get_convol_img_fn

* refactor: add warning msg

* refactor: encapsulate the report printing in a function

* docs: add some documentation in the function _get_convol_img_fn

* docs: add realise

* refactor: change function _get_convol_img_fn for more clarity

* refactor: run pre-commit

* test: refactor tests to delete the error for unavailable backends

* feat: delete not implemented error in convolutional module

* revert the last two commits

* docs: add comments with the reason of the error

---------

Co-authored-by: Francisco Muñoz <[email protected]>
  • Loading branch information
framunoz and Francisco Muñoz authored Jan 10, 2025
1 parent 39cd6ec commit 200322b
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 93 deletions.
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ This release also contains few bug fixes, concerning the support of any metric i
- Notes before depreciating partial Gromov-Wasserstein function in `ot.partial` moved to ot.gromov (PR #663)
- Create `ot.gromov._partial` add new features `loss_fun = "kl_loss"` and `symmetry=False` to all solvers while increasing speed + updating adequatly `ot.solvers` (PR #663)
- Added `ot.unbalanced.sinkhorn_unbalanced_translation_invariant` (PR #676)
- Refactored `ot.bregman._convolutional` to improve readability (PR #709)

#### Closed issues
- Fixed `ot.gaussian` ignoring weights when computing means (PR #649, Issue #648)
Expand Down
162 changes: 69 additions & 93 deletions ot/bregman/_convolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,53 @@

import warnings

from ..utils import list_to_array
from ..backend import get_backend
from ..utils import list_to_array

_warning_msg = (
"Convolutional Sinkhorn did not converge. "
"Try a larger number of iterations `numItermax` "
"or a larger entropy `reg`."
)


def _get_convol_img_fn(nx, width, height, reg, type_as, log_domain=False):
"""Return the convolution operator for 2D images.
The function constructed is equivalent to blurring on horizontal then vertical directions."""
t1 = nx.linspace(0, 1, width, type_as=type_as)
Y1, X1 = nx.meshgrid(t1, t1)
M1 = -((X1 - Y1) ** 2) / reg

t2 = nx.linspace(0, 1, height, type_as=type_as)
Y2, X2 = nx.meshgrid(t2, t2)
M2 = -((X2 - Y2) ** 2) / reg

# If normal domain is selected, we can use M1 and M2 to compute the convolution
if not log_domain:
K1, K2 = nx.exp(M1), nx.exp(M2)

def convol_imgs(imgs):
kx = nx.einsum("...ij,kjl->kil", K1, imgs)
kxy = nx.einsum("...ij,klj->kli", K2, kx)
return kxy

# Else, we can use M1 and M2 to compute the convolution in log-domain
else:

def convol_imgs(log_imgs):
log_imgs = nx.logsumexp(M1[:, :, None] + log_imgs[None], axis=1)
log_imgs = nx.logsumexp(M2[:, :, None] + log_imgs.T[None], axis=1).T
return log_imgs

return convol_imgs


def _print_report(ii, err):
"""Print the report of the iteration."""
if ii % 200 == 0:
print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
print("{:5d}|{:8e}|".format(ii, err))


def convolutional_barycenter2d(
Expand Down Expand Up @@ -133,37 +178,26 @@ def _convolutional_barycenter2d(
"""

A = list_to_array(A)
n_hists, width, height = A.shape

nx = get_backend(A)

if weights is None:
weights = nx.ones((A.shape[0],), type_as=A) / A.shape[0]
weights = nx.ones((n_hists,), type_as=A) / n_hists
else:
assert len(weights) == A.shape[0]
assert len(weights) == n_hists

if log:
log = {"err": []}

bar = nx.ones(A.shape[1:], type_as=A)
bar = nx.ones((width, height), type_as=A)
bar /= nx.sum(bar)
U = nx.ones(A.shape, type_as=A)
V = nx.ones(A.shape, type_as=A)
err = 1

# build the convolution operator
# this is equivalent to blurring on horizontal then vertical directions
t = nx.linspace(0, 1, A.shape[1], type_as=A)
[Y, X] = nx.meshgrid(t, t)
K1 = nx.exp(-((X - Y) ** 2) / reg)

t = nx.linspace(0, 1, A.shape[2], type_as=A)
[Y, X] = nx.meshgrid(t, t)
K2 = nx.exp(-((X - Y) ** 2) / reg)

def convol_imgs(imgs):
kx = nx.einsum("...ij,kjl->kil", K1, imgs)
kxy = nx.einsum("...ij,klj->kli", K2, kx)
return kxy
convol_imgs = _get_convol_img_fn(nx, width, height, reg, type_as=A)

KU = convol_imgs(U)
for ii in range(numItermax):
Expand All @@ -177,24 +211,18 @@ def convol_imgs(imgs):
# log and verbose print
if log:
log["err"].append(err)

if verbose:
if ii % 200 == 0:
print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
print("{:5d}|{:8e}|".format(ii, err))
_print_report(ii, err)
if err < stopThr:
break

else:
if warn:
warnings.warn(
"Convolutional Sinkhorn did not converge. "
"Try a larger number of iterations `numItermax` "
"or a larger entropy `reg`."
)
warnings.warn(_warning_msg)
if log:
log["niter"] = ii
log["U"] = U
log["V"] = V
return bar, log
else:
return bar
Expand All @@ -218,6 +246,8 @@ def _convolutional_barycenter2d_log(
A = list_to_array(A)

nx = get_backend(A)
# This error is raised because we are using mutable assignment in the line
# `log_KU[k] = ...` which is not allowed in Jax and TF.
if nx.__name__ in ("jax", "tf"):
raise NotImplementedError(
"Log-domain functions are not yet implemented"
Expand All @@ -236,19 +266,7 @@ def _convolutional_barycenter2d_log(

err = 1
# build the convolution operator
# this is equivalent to blurring on horizontal then vertical directions
t = nx.linspace(0, 1, width, type_as=A)
[Y, X] = nx.meshgrid(t, t)
M1 = -((X - Y) ** 2) / reg

t = nx.linspace(0, 1, height, type_as=A)
[Y, X] = nx.meshgrid(t, t)
M2 = -((X - Y) ** 2) / reg

def convol_img(log_img):
log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1)
log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T
return log_img
convol_img = _get_convol_img_fn(nx, width, height, reg, type_as=A, log_domain=True)

logA = nx.log(A + stabThr)
log_KU, G, F = nx.zeros((3, *logA.shape), type_as=A)
Expand All @@ -265,22 +283,15 @@ def convol_img(log_img):
# log and verbose print
if log:
log["err"].append(err)

if verbose:
if ii % 200 == 0:
print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
print("{:5d}|{:8e}|".format(ii, err))
_print_report(ii, err)
if err < stopThr:
break
G = log_bar[None, :, :] - log_KU

else:
if warn:
warnings.warn(
"Convolutional Sinkhorn did not converge. "
"Try a larger number of iterations `numItermax` "
"or a larger entropy `reg`."
)
warnings.warn(_warning_msg)
if log:
log["niter"] = ii
return nx.exp(log_bar), log
Expand Down Expand Up @@ -417,23 +428,11 @@ def _convolutional_barycenter2d_debiased(
bar /= width * height
U = nx.ones(A.shape, type_as=A)
V = nx.ones(A.shape, type_as=A)
c = nx.ones(A.shape[1:], type_as=A)
c = nx.ones((width, height), type_as=A)
err = 1

# build the convolution operator
# this is equivalent to blurring on horizontal then vertical directions
t = nx.linspace(0, 1, width, type_as=A)
[Y, X] = nx.meshgrid(t, t)
K1 = nx.exp(-((X - Y) ** 2) / reg)

t = nx.linspace(0, 1, height, type_as=A)
[Y, X] = nx.meshgrid(t, t)
K2 = nx.exp(-((X - Y) ** 2) / reg)

def convol_imgs(imgs):
kx = nx.einsum("...ij,kjl->kil", K1, imgs)
kxy = nx.einsum("...ij,klj->kli", K2, kx)
return kxy
convol_imgs = _get_convol_img_fn(nx, width, height, reg, type_as=A)

KU = convol_imgs(U)
for ii in range(numItermax):
Expand All @@ -451,26 +450,20 @@ def convol_imgs(imgs):
# log and verbose print
if log:
log["err"].append(err)

if verbose:
if ii % 200 == 0:
print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
print("{:5d}|{:8e}|".format(ii, err))
_print_report(ii, err)

# debiased Sinkhorn does not converge monotonically
# guarantee a few iterations are done before stopping
if err < stopThr and ii > 20:
break
else:
if warn:
warnings.warn(
"Sinkhorn did not converge. You might want to "
"increase the number of iterations `numItermax` "
"or the regularization parameter `reg`."
)
warnings.warn(_warning_msg)
if log:
log["niter"] = ii
log["U"] = U
log["V"] = V
return bar, log
else:
return bar
Expand All @@ -492,6 +485,8 @@ def _convolutional_barycenter2d_debiased_log(
A = list_to_array(A)
n_hists, width, height = A.shape
nx = get_backend(A)
# This error is raised because we are using mutable assignment in the line
# `log_KU[k] = ...` which is not allowed in Jax and TF.
if nx.__name__ in ("jax", "tf"):
raise NotImplementedError(
"Log-domain functions are not yet implemented"
Expand All @@ -507,19 +502,7 @@ def _convolutional_barycenter2d_debiased_log(

err = 1
# build the convolution operator
# this is equivalent to blurring on horizontal then vertical directions
t = nx.linspace(0, 1, width, type_as=A)
[Y, X] = nx.meshgrid(t, t)
M1 = -((X - Y) ** 2) / reg

t = nx.linspace(0, 1, height, type_as=A)
[Y, X] = nx.meshgrid(t, t)
M2 = -((X - Y) ** 2) / reg

def convol_img(log_img):
log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1)
log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T
return log_img
convol_img = _get_convol_img_fn(nx, width, height, reg, type_as=A, log_domain=True)

logA = nx.log(A + stabThr)
log_bar, c = nx.zeros((2, width, height), type_as=A)
Expand All @@ -540,22 +523,15 @@ def convol_img(log_img):
# log and verbose print
if log:
log["err"].append(err)

if verbose:
if ii % 200 == 0:
print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
print("{:5d}|{:8e}|".format(ii, err))
_print_report(ii, err)
if err < stopThr and ii > 20:
break
G = log_bar[None, :, :] - log_KU

else:
if warn:
warnings.warn(
"Convolutional Sinkhorn did not converge. "
"Try a larger number of iterations `numItermax` "
"or a larger entropy `reg`."
)
warnings.warn(_warning_msg)
if log:
log["niter"] = ii
return nx.exp(log_bar), log
Expand Down

0 comments on commit 200322b

Please sign in to comment.