Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Refactor: improve readability in the convolutional module #709

1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ This release also contains few bug fixes, concerning the support of any metric i
- Notes before deprecating partial Gromov-Wasserstein function in `ot.partial` moved to `ot.gromov` (PR #663)
- Create `ot.gromov._partial` and add new features `loss_fun = "kl_loss"` and `symmetry=False` to all solvers while increasing speed + updating adequately `ot.solvers` (PR #663)
- Added `ot.unbalanced.sinkhorn_unbalanced_translation_invariant` (PR #676)
- Refactored `ot.bregman._convolutional` to improve readability (PR #709)

#### Closed issues
- Fixed `ot.gaussian` ignoring weights when computing means (PR #649, Issue #648)
Expand Down
162 changes: 69 additions & 93 deletions ot/bregman/_convolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,53 @@

import warnings

from ..utils import list_to_array
from ..backend import get_backend
from ..utils import list_to_array

# Single shared message for the non-convergence warning, so that every
# solver in this module that warns with it reports the same text.
_warning_msg = (
    "Convolutional Sinkhorn did not converge. Try a larger number of "
    "iterations `numItermax` or a larger entropy `reg`."
)


def _get_convol_img_fn(nx, width, height, reg, type_as, log_domain=False):
    """Return the separable convolution operator for 2D images.

    The returned callable is equivalent to blurring along the first image
    axis (size ``width``) and then along the second image axis (size
    ``height``) with the kernel ``exp(-(x - y) ** 2 / reg)`` evaluated on a
    regular grid of ``[0, 1]``.

    Parameters
    ----------
    nx : backend
        Backend object; must provide ``linspace``, ``meshgrid``, ``exp``,
        ``einsum`` and ``logsumexp``.
    width : int
        Number of grid points along the first image axis.
    height : int
        Number of grid points along the second image axis.
    reg : float
        Regularization term dividing the squared distances.
    type_as : array-like
        Forwarded to ``nx.linspace`` as ``type_as``.
    log_domain : bool, optional
        If True, the returned operator works on log-images with
        ``logsumexp``; otherwise it works on plain images with ``einsum``.

    Returns
    -------
    convol_imgs : callable
        * ``log_domain=False``: maps a stack of images of shape
          ``(n, width, height)`` to the blurred stack of the same shape.
        * ``log_domain=True``: maps a *single* 2D log-image of shape
          ``(width, height)`` to the log of its blurred version. Unlike the
          standard-domain operator it does not accept a stack of images.
    """

    def _log_kernel(size):
        # Negative squared distances between the grid points, scaled by reg.
        grid = nx.linspace(0, 1, size, type_as=type_as)
        Y, X = nx.meshgrid(grid, grid)
        return -((X - Y) ** 2) / reg

    M1 = _log_kernel(width)
    M2 = _log_kernel(height)

    if log_domain:
        # Log-domain: the log-kernels are used directly, and the sums of the
        # convolution become logsumexp reductions.
        def convol_imgs(log_imgs):
            log_imgs = nx.logsumexp(M1[:, :, None] + log_imgs[None], axis=1)
            # Transpose so the second image axis is reduced, then restore
            # the original (width, height) layout.
            log_imgs = nx.logsumexp(M2[:, :, None] + log_imgs.T[None], axis=1).T
            return log_imgs

        return convol_imgs

    # Standard domain: exponentiate the kernels once, outside the closure.
    K1, K2 = nx.exp(M1), nx.exp(M2)

    def convol_imgs(imgs):
        kx = nx.einsum("...ij,kjl->kil", K1, imgs)
        kxy = nx.einsum("...ij,klj->kli", K2, kx)
        return kxy

    return convol_imgs


def _print_report(ii, err):
    """Print one row of the verbose convergence report.

    A two-line table header is printed first whenever ``ii`` is a multiple
    of 200, followed in every case by the row for the current iteration.

    Parameters
    ----------
    ii : int
        Iteration index.
    err : float
        Error value displayed for this iteration.
    """
    # Repeat the header every 200 rows so it stays visible in long runs.
    if ii % 200 == 0:
        print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
    print("{:5d}|{:8e}|".format(ii, err))


def convolutional_barycenter2d(
Expand Down Expand Up @@ -133,37 +178,26 @@
"""

A = list_to_array(A)
n_hists, width, height = A.shape

nx = get_backend(A)

if weights is None:
weights = nx.ones((A.shape[0],), type_as=A) / A.shape[0]
weights = nx.ones((n_hists,), type_as=A) / n_hists
else:
assert len(weights) == A.shape[0]
assert len(weights) == n_hists

if log:
log = {"err": []}

bar = nx.ones(A.shape[1:], type_as=A)
bar = nx.ones((width, height), type_as=A)
bar /= nx.sum(bar)
U = nx.ones(A.shape, type_as=A)
V = nx.ones(A.shape, type_as=A)
err = 1

# build the convolution operator
# this is equivalent to blurring on horizontal then vertical directions
t = nx.linspace(0, 1, A.shape[1], type_as=A)
[Y, X] = nx.meshgrid(t, t)
K1 = nx.exp(-((X - Y) ** 2) / reg)

t = nx.linspace(0, 1, A.shape[2], type_as=A)
[Y, X] = nx.meshgrid(t, t)
K2 = nx.exp(-((X - Y) ** 2) / reg)

def convol_imgs(imgs):
kx = nx.einsum("...ij,kjl->kil", K1, imgs)
kxy = nx.einsum("...ij,klj->kli", K2, kx)
return kxy
convol_imgs = _get_convol_img_fn(nx, width, height, reg, type_as=A)

KU = convol_imgs(U)
for ii in range(numItermax):
Expand All @@ -177,24 +211,18 @@
# log and verbose print
if log:
log["err"].append(err)

if verbose:
if ii % 200 == 0:
print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
print("{:5d}|{:8e}|".format(ii, err))
_print_report(ii, err)
if err < stopThr:
break

else:
if warn:
warnings.warn(
"Convolutional Sinkhorn did not converge. "
"Try a larger number of iterations `numItermax` "
"or a larger entropy `reg`."
)
warnings.warn(_warning_msg)
if log:
log["niter"] = ii
log["U"] = U
log["V"] = V
return bar, log
else:
return bar
Expand All @@ -218,6 +246,8 @@
A = list_to_array(A)

nx = get_backend(A)
# This error is raised because we are using mutable assignment in the line
# `log_KU[k] = ...` which is not allowed in Jax and TF.
if nx.__name__ in ("jax", "tf"):
raise NotImplementedError(
"Log-domain functions are not yet implemented"
Expand All @@ -236,19 +266,7 @@

err = 1
# build the convolution operator
# this is equivalent to blurring on horizontal then vertical directions
t = nx.linspace(0, 1, width, type_as=A)
[Y, X] = nx.meshgrid(t, t)
M1 = -((X - Y) ** 2) / reg

t = nx.linspace(0, 1, height, type_as=A)
[Y, X] = nx.meshgrid(t, t)
M2 = -((X - Y) ** 2) / reg

def convol_img(log_img):
log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1)
log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T
return log_img
convol_img = _get_convol_img_fn(nx, width, height, reg, type_as=A, log_domain=True)

logA = nx.log(A + stabThr)
log_KU, G, F = nx.zeros((3, *logA.shape), type_as=A)
Expand All @@ -265,22 +283,15 @@
# log and verbose print
if log:
log["err"].append(err)

if verbose:
if ii % 200 == 0:
print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
print("{:5d}|{:8e}|".format(ii, err))
_print_report(ii, err)
if err < stopThr:
break
G = log_bar[None, :, :] - log_KU

else:
if warn:
warnings.warn(
"Convolutional Sinkhorn did not converge. "
"Try a larger number of iterations `numItermax` "
"or a larger entropy `reg`."
)
warnings.warn(_warning_msg)
if log:
log["niter"] = ii
return nx.exp(log_bar), log
Expand Down Expand Up @@ -417,23 +428,11 @@
bar /= width * height
U = nx.ones(A.shape, type_as=A)
V = nx.ones(A.shape, type_as=A)
c = nx.ones(A.shape[1:], type_as=A)
c = nx.ones((width, height), type_as=A)
err = 1

# build the convolution operator
# this is equivalent to blurring on horizontal then vertical directions
t = nx.linspace(0, 1, width, type_as=A)
[Y, X] = nx.meshgrid(t, t)
K1 = nx.exp(-((X - Y) ** 2) / reg)

t = nx.linspace(0, 1, height, type_as=A)
[Y, X] = nx.meshgrid(t, t)
K2 = nx.exp(-((X - Y) ** 2) / reg)

def convol_imgs(imgs):
kx = nx.einsum("...ij,kjl->kil", K1, imgs)
kxy = nx.einsum("...ij,klj->kli", K2, kx)
return kxy
convol_imgs = _get_convol_img_fn(nx, width, height, reg, type_as=A)

KU = convol_imgs(U)
for ii in range(numItermax):
Expand All @@ -451,26 +450,20 @@
# log and verbose print
if log:
log["err"].append(err)

if verbose:
if ii % 200 == 0:
print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
print("{:5d}|{:8e}|".format(ii, err))
_print_report(ii, err)

# debiased Sinkhorn does not converge monotonically
# guarantee a few iterations are done before stopping
if err < stopThr and ii > 20:
break
else:
if warn:
warnings.warn(
"Sinkhorn did not converge. You might want to "
"increase the number of iterations `numItermax` "
"or the regularization parameter `reg`."
)
warnings.warn(_warning_msg)
if log:
log["niter"] = ii
log["U"] = U
log["V"] = V
return bar, log
else:
return bar
Expand All @@ -492,6 +485,8 @@
A = list_to_array(A)
n_hists, width, height = A.shape
nx = get_backend(A)
# This error is raised because we are using mutable assignment in the line
# `log_KU[k] = ...` which is not allowed in Jax and TF.
if nx.__name__ in ("jax", "tf"):
raise NotImplementedError(
"Log-domain functions are not yet implemented"
Expand All @@ -507,19 +502,7 @@

err = 1
# build the convolution operator
# this is equivalent to blurring on horizontal then vertical directions
t = nx.linspace(0, 1, width, type_as=A)
[Y, X] = nx.meshgrid(t, t)
M1 = -((X - Y) ** 2) / reg

t = nx.linspace(0, 1, height, type_as=A)
[Y, X] = nx.meshgrid(t, t)
M2 = -((X - Y) ** 2) / reg

def convol_img(log_img):
log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1)
log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T
return log_img
convol_img = _get_convol_img_fn(nx, width, height, reg, type_as=A, log_domain=True)

logA = nx.log(A + stabThr)
log_bar, c = nx.zeros((2, width, height), type_as=A)
Expand All @@ -540,22 +523,15 @@
# log and verbose print
if log:
log["err"].append(err)

if verbose:
if ii % 200 == 0:
print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
print("{:5d}|{:8e}|".format(ii, err))
_print_report(ii, err)
if err < stopThr and ii > 20:
break
G = log_bar[None, :, :] - log_KU

else:
if warn:
warnings.warn(
"Convolutional Sinkhorn did not converge. "
"Try a larger number of iterations `numItermax` "
"or a larger entropy `reg`."
)
warnings.warn(_warning_msg)
if log:
log["niter"] = ii
return nx.exp(log_bar), log
Expand Down
Loading