Merge pull request #22 from deel-ai/feat/update_normalizers
Pytest unit tests, support for new losses, and normalization enhancements
cofri authored Nov 27, 2024
2 parents 49b1a50 + 6d7d3a1 commit 4360edf
Showing 43 changed files with 6,873 additions and 1,515 deletions.
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/generic_bug_report.yaml
@@ -58,7 +58,7 @@ body:
description: |
examples:
- **OS**: linux
- **Python version**: 3.7
- **Python version**: 3.11
- **Packages used version**: PyTorch, Numpy, scikit-learn, etc..
value: |
- OS:
2 changes: 1 addition & 1 deletion .github/workflows/python-lints.yml
@@ -8,7 +8,7 @@ jobs:
strategy:
max-parallel: 4
matrix:
python-version: [3.7, 3.8, 3.9]
python-version: [3.9, "3.10", 3.11]

steps:
- uses: actions/checkout@v1
12 changes: 9 additions & 3 deletions .github/workflows/python-tests.yml
@@ -8,7 +8,13 @@ jobs:
strategy:
max-parallel: 4
matrix:
python-version: [3.7, 3.8, 3.9]
include:
- python-version: 3.9
pt-version: 1.13.1
- python-version: "3.10"
pt-version: 2.1.2
- python-version: "3.11"
pt-version: 2.4.1

steps:
- uses: actions/checkout@v1
@@ -20,5 +26,5 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install tox
- name: Test with tox
run: tox -e py$(echo ${{ matrix.python-version }} | tr -d .)
- name: Test with tox (Python ${{ matrix.python-version }} - Torch ${{ matrix.pt-version }})
run: tox -e py$(echo ${{ matrix.python-version }} | tr -d .)-pt$(echo ${{ matrix.pt-version }})
1 change: 0 additions & 1 deletion .gitignore
@@ -10,7 +10,6 @@ __pycache__
build/
dist/

.github/

# Test files
logs/*
2 changes: 2 additions & 0 deletions MANIFEST.in
@@ -0,0 +1,2 @@
include deel/torchlip/VERSION
include LICENSE
1 change: 1 addition & 0 deletions deel/torchlip/VERSION
@@ -0,0 +1 @@
0.2
6 changes: 5 additions & 1 deletion deel/torchlip/__init__.py
@@ -26,7 +26,10 @@
# =====================================================================================
# flake8: noqa

__version__ = "0.1.1"
from os import path

with open(path.join(path.dirname(__file__), "VERSION")) as f:
__version__ = f.read().strip()

from . import functional, init, normalizers, utils
from .modules import *
@@ -39,6 +42,7 @@
"GroupSort2",
"HKRLoss",
"HKRMulticlassLoss",
"SoftHKRMulticlassLoss",
"HingeMarginLoss",
"HingeMulticlassLoss",
"InvertibleDownSampling",
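For illustration (not part of the diff): with this change the package version is read from the bundled VERSION file, and the new loss is exposed at the package level. A minimal sketch, assuming this branch is installed:

import deel.torchlip as torchlip

# The version string now comes from deel/torchlip/VERSION (here "0.2").
print(torchlip.__version__)

# SoftHKRMulticlassLoss is newly listed in __all__ and importable from the top level.
from deel.torchlip import SoftHKRMulticlassLoss  # noqa: F401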
206 changes: 145 additions & 61 deletions deel/torchlip/functional.py
@@ -278,13 +278,50 @@ def lipschitz_prelu(


# Losses
def apply_reduction(val: torch.Tensor, reduction: str) -> torch.Tensor:
if reduction == "auto":
reduction = "mean"
red = getattr(torch, reduction, None)
if red is None:
return val
return red(val)
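A small illustrative sketch (not part of the diff) of how the `apply_reduction` helper behaves; the tensor values below are made up:

import torch
from deel.torchlip.functional import apply_reduction

values = torch.tensor([1.0, 2.0, 3.0])
apply_reduction(values, "auto")  # "auto" maps to "mean" -> tensor(2.)
apply_reduction(values, "sum")   # any reduction name that exists on torch is applied -> tensor(6.)
apply_reduction(values, "none")  # unknown reductions return the tensor unchanged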


def kr_loss(
input: torch.Tensor,
target: torch.Tensor,
true_values: Tuple[int, int] = (0, 1),
) -> torch.Tensor:
def kr_loss(input: torch.Tensor, target: torch.Tensor, multi_gpu=False) -> torch.Tensor:
r"""
Loss to estimate the Wasserstein-1 distance using Kantorovich-Rubinstein duality,
as per
.. math::
\mathcal{W}(\mu, \nu) = \sup\limits_{f\in{}Lip_1(\Omega)}
\underset{\mathbf{x}\sim{}\mu}{\mathbb{E}}[f(\mathbf{x})]
- \underset{\mathbf{x}\sim{}\nu}{\mathbb{E}}[f(\mathbf{x})]
where :math:`\mu` and :math:`\nu` are the distributions corresponding to the
two possible labels as specified by their sign.
`target` accepts label values in (0, 1), (-1, 1), or pre-processed with the
`deel.torchlip.functional.process_labels_for_multi_gpu()` function.
Using a multi-GPU/TPU strategy requires setting `multi_gpu` to True and
pre-processing the labels `target` with the
`deel.torchlip.functional.process_labels_for_multi_gpu()` function.
Args:
input: Tensor of arbitrary shape.
target: Tensor of the same shape as input.
multi_gpu (bool): set to True when running on multi-GPU/TPU
Returns:
The Wasserstein-1 loss between ``input`` and ``target``.
"""
if multi_gpu:
return kr_loss_multi_gpu(input, target)
else:
return kr_loss_standard(input, target)


def kr_loss_standard(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
r"""
Loss to estimate the Wasserstein-1 distance using Kantorovich-Rubinstein duality,
as per
@@ -295,26 +332,62 @@ def kr_loss(
- \underset{\mathbf{x}\sim{}\nu}{\mathbb{E}}[f(\mathbf{x})]
where :math:`\mu` and :math:`\nu` are the distributions corresponding to the
two possible labels as specific by ``true_values``.
two possible labels as specified by their sign.
`target` accepts label values in (0, 1), (-1, 1)
Args:
input: Tensor of arbitrary shape.
target: Tensor of the same shape as input.
true_values: Tuple containing the two label for the predicted class.
Returns:
The Wasserstein-1 loss between ``input`` and ``target``.
"""

v0, v1 = true_values
target = target.view(input.shape)
return torch.mean(input[target == v0]) - torch.mean(input[target == v1])
pos_target = (target > 0).to(input.dtype)
mean_pos = torch.mean(pos_target, dim=0)
# pos factor = batch_size/number of positive samples
pos_factor = torch.nan_to_num(1.0 / mean_pos)
# neg factor = batch_size/number of negative samples
neg_factor = -torch.nan_to_num(1.0 / (1.0 - mean_pos))

weighted_input = torch.where(target > 0, pos_factor, neg_factor) * input
# Since element-wise KR terms are averaged by loss reduction later on, it is
# necessary to multiply by batch_size here.
# In the binary case (`target` of shape (batch_size, 1)), `torch.mean(dim=-1)`
# behaves like `torch.squeeze()` to return an element-wise loss of shape (batch_size,).
return torch.mean(weighted_input, dim=-1)
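An illustrative sketch (not part of the diff) of the single-device binary case; the scores and labels below are made up:

import torch
from deel.torchlip.functional import kr_loss

scores = torch.tensor([[1.5], [0.5], [-0.3], [-1.0]])   # model outputs, shape (4, 1)
labels = torch.tensor([[1.0], [1.0], [0.0], [0.0]])     # labels in {0, 1} (or {-1, 1})

# Element-wise KR terms of shape (4,): positive samples are weighted by batch/#pos,
# negative samples by -batch/#neg; averaging them estimates the W1 distance.
elementwise_kr = kr_loss(scores, labels)
w1_estimate = elementwise_kr.mean()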


def kr_loss_multi_gpu(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
r"""Returns the element-wise KR loss when computing with a multi-GPU/TPU strategy.
`target` and `input` can be either of shape (batch_size, 1) or
(batch_size, # classes).
When using this loss function, the labels `target` must be pre-processed with the
`process_labels_for_multi_gpu()` function.
Args:
input: Tensor of arbitrary shape.
target: pre-processed Tensor of the same shape as input.
Returns:
The Wasserstein-1 loss between ``input`` and ``target``.
"""
target = target.view(input.shape).to(input.dtype)
# Since the information of batch size was included in `target` by
# `process_labels_for_multi_gpu()`, there is no need here to multiply by batch size.
# In binary case (`target` of shape (batch_size, 1)), `torch.mean(dim=-1)`
# behaves like `torch.squeeze()` to return element-wise loss of shape (batch_size,)
return torch.mean(input * target, dim=-1)


def neg_kr_loss(
input: torch.Tensor,
target: torch.Tensor,
true_values: Tuple[int, int] = (0, 1),
multi_gpu=False,
) -> torch.Tensor:
"""
Loss to estimate the negative Wasserstein-1 distance using Kantorovich-Rubinstein
@@ -323,15 +396,15 @@ def neg_kr_loss(
Args:
input: Tensor of arbitrary shape.
target: Tensor of the same shape as input.
true_values: Tuple containing the two label for the predicted classes.
multi_gpu (bool): set to True when running on multi-GPU/TPU
Returns:
The negative Wasserstein-1 loss between ``input`` and ``target``.
See Also:
:py:func:`kr_loss`
"""
return -kr_loss(input, target, true_values)
return -kr_loss(input, target, multi_gpu=multi_gpu)


def hinge_margin_loss(
@@ -356,20 +429,17 @@ def hinge_margin_loss(
The hinge margin loss.
"""
target = target.view(input.shape)
return torch.mean(
torch.max(
torch.zeros_like(input),
min_margin - torch.sign(target) * input,
)
)
sign_target = torch.where(target > 0, 1.0, -1.0).to(input.dtype)
hinge = F.relu(min_margin / 2.0 - sign_target * input)
return torch.mean(hinge, dim=-1)


def hkr_loss(
input: torch.Tensor,
target: torch.Tensor,
alpha: float,
min_margin: float = 1.0,
true_values: Tuple[int, int] = (-1, 1),
multi_gpu=False,
) -> torch.Tensor:
"""
Loss to estimate the Wasserstein-1 distance with a hinge regularization using
@@ -378,9 +448,9 @@ def hkr_loss(
Args:
input: Tensor of arbitrary shape.
target: Tensor of the same shape as input.
alpha: Regularization factor between the hinge and the KR loss.
alpha: Regularization factor ([0,1]) between the hinge and the KR loss.
min_margin: Minimal margin for the hinge loss.
true_values: tuple containing the two label for each predicted class.
multi_gpu (bool): set to True when running on multi-GPU/TPU
Returns:
The regularized Wasserstein-1 loss.
@@ -389,37 +459,17 @@ def hkr_loss(
:py:func:`hinge_margin_loss`
:py:func:`kr_loss`
"""
if alpha == np.inf: # alpha negative hinge only
kr_loss_fct = kr_loss_multi_gpu if multi_gpu else kr_loss
assert alpha <= 1.0
if alpha == 1.0: # alpha for hinge only
return hinge_margin_loss(input, target, min_margin)

if alpha == 0:
return -kr_loss_fct(input, target)
# true value: positive value should be the first to be coherent with the
# hinge loss (positive y_pred)
return alpha * hinge_margin_loss(input, target, min_margin) - kr_loss(
input, target, (true_values[1], true_values[0])
)


def kr_multiclass_loss(
input: torch.Tensor,
target: torch.Tensor,
) -> torch.Tensor:
r"""
Loss to estimate average of W1 distance using Kantorovich-Rubinstein
duality over outputs. In this multiclass setup the KR term is computed
for each class and then averaged.
Args:
input: Tensor of arbitrary shape.
target: Tensor of the same shape as input.
target has to be one hot encoded (labels being 1s and 0s ).
Returns:
The Wasserstein multiclass loss between ``input`` and ``target``.
"""
esp_true_true = torch.sum(input * target, 0) / torch.sum(target, 0)
esp_false_true = torch.sum(input * (1 - target), 0) / torch.sum((1 - target), 0)

return torch.mean(esp_true_true - esp_false_true)
return alpha * hinge_margin_loss(input, target, min_margin) - (
1 - alpha
) * kr_loss_fct(input, target)
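An illustrative sketch (not part of the diff) of the binary HKR loss; alpha weights the hinge term and (1 - alpha) the (negated) KR term. The tensor values below are made up:

import torch
from deel.torchlip.functional import apply_reduction, hkr_loss

scores = torch.tensor([[1.5], [0.5], [-0.3], [-1.0]])
labels = torch.tensor([[1.0], [1.0], [0.0], [0.0]])

elementwise = hkr_loss(scores, labels, alpha=0.9, min_margin=1.0)  # shape (4,)
loss = apply_reduction(elementwise, "auto")  # scalar, averaged over the batch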


def hinge_multiclass_loss(
@@ -430,7 +480,7 @@
"""
Loss to estimate the Hinge loss in a multiclass setup. It computes the
element-wise hinge term. Note that this formulation differs from the
one commonly found in tensorflow/pytorch (with marximise the difference
one commonly found in tensorflow/pytorch (which maximises the difference
between the two largest logits). This formulation is consistent with the
binary classification loss used in a multiclass fashion.
@@ -445,17 +495,20 @@
Returns:
The hinge margin multiclass loss.
"""
return torch.mean(
((target.shape[-1] - 2) * target + 1)
* F.relu(min_margin - (2 * target - 1) * input)
)
sign_target = torch.where(target > 0, 1.0, -1.0).to(input.dtype)
hinge = F.relu(min_margin / 2.0 - sign_target * input)
# reweight positive elements
factor = target.shape[-1] - 1.0
hinge = torch.where(target > 0, hinge * factor, hinge)
return torch.mean(hinge, dim=-1)


def hkr_multiclass_loss(
input: torch.Tensor,
target: torch.Tensor,
alpha: float = 0.0,
min_margin: float = 1.0,
multi_gpu=False,
) -> torch.Tensor:
"""
Loss to estimate the Wasserstein-1 distance with a hinge regularization using
@@ -464,9 +517,9 @@ def hkr_multiclass_loss(
Args:
input: Tensor of arbitrary shape.
target: Tensor of the same shape as input.
alpha: Regularization factor between the hinge and the KR loss.
alpha: Regularization factor ([0,1]) between the hinge and the KR loss.
min_margin: Minimal margin for the hinge loss.
true_values: tuple containing the two label for each predicted class.
multi_gpu (bool): set to True when running on multi-GPU/TPU
Returns:
The regularized Wasserstein-1 loss.
@@ -476,11 +529,42 @@
:py:func:`kr_loss`
"""

if alpha == np.inf: # alpha negative hinge only
assert alpha <= 1.0
kr_loss_fct = kr_loss_multi_gpu if multi_gpu else kr_loss
if alpha == 1.0: # alpha hinge only
return hinge_multiclass_loss(input, target, min_margin)
elif alpha == 0.0: # alpha = 0 => KR only
return -kr_multiclass_loss(input, target)
return -kr_loss_fct(input, target)
else:
return -kr_multiclass_loss(input, target) + alpha * hinge_multiclass_loss(
input, target, min_margin
)
return alpha * hinge_multiclass_loss(input, target, min_margin) - (
1 - alpha
) * kr_loss_fct(input, target)
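An illustrative sketch (not part of the diff) of the multiclass HKR loss with one-hot targets (2 samples, 3 classes); the values below are made up:

import torch
from deel.torchlip.functional import hkr_multiclass_loss

logits = torch.tensor([[2.0, -1.0, -0.5],
                       [-0.2, 1.2, -1.0]])
one_hot = torch.tensor([[1.0, 0.0, 0.0],
                        [0.0, 1.0, 0.0]])

# Element-wise loss of shape (2,); the hinge term on positive (one-hot) entries
# is re-weighted by (num_classes - 1).
elementwise = hkr_multiclass_loss(logits, one_hot, alpha=0.5, min_margin=1.0)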


def process_labels_for_multi_gpu(labels: torch.Tensor) -> torch.Tensor:
"""Process labels to be fed to any loss based on KR estimation with a multi-GPU/TPU
strategy.
When using a multi-GPU/TPU strategy, the flag `multi_gpu` in KR-based losses must be
set to True and the labels have to be pre-processed with this function.
For binary classification, the labels should be of shape [batch_size, 1].
For multiclass problems, the labels must be one-hot encoded (1 or 0) with shape
[batch_size, number of classes].
Args:
labels (torch.Tensor): tensor containing the labels
Returns:
torch.Tensor: labels processed for KR-based losses with multi-GPU/TPU strategy.
"""
pos_labels = torch.where(labels > 0, 1.0, 0.0).to(labels.dtype)
mean_pos = torch.mean(pos_labels, dim=0)
# pos factor = batch_size/number of positive samples
pos_factor = torch.nan_to_num(1.0 / mean_pos)
# neg factor = batch_size/number of negative samples
neg_factor = -torch.nan_to_num(1.0 / (1.0 - mean_pos))

# Since element-wise KR terms are averaged by loss reduction later on, it is
# necessary to multiply by batch_size here.
return torch.where(labels > 0, pos_factor, neg_factor)
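An illustrative sketch (not part of the diff) of the multi-GPU/TPU path: the labels are pre-processed once, then KR-based losses are called with `multi_gpu=True`. The values below are made up:

import torch
from deel.torchlip.functional import hkr_loss, process_labels_for_multi_gpu

labels = torch.tensor([[1.0], [1.0], [0.0], [0.0]])
scores = torch.tensor([[1.5], [0.5], [-0.3], [-1.0]])

# Positive labels become +batch/#pos, negatives -batch/#neg (here +2 and -2).
weighted_labels = process_labels_for_multi_gpu(labels)
elementwise = hkr_loss(scores, weighted_labels, alpha=0.9, min_margin=1.0, multi_gpu=True)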