Merge pull request #605 from KevinMusgrave/dev
v2.1.0
Kevin Musgrave authored Apr 5, 2023
2 parents 691a635 + d94576c commit 03d5192
Showing 7 changed files with 240 additions and 1 deletion.
1 change: 1 addition & 0 deletions CONTENTS.md
@@ -27,6 +27,7 @@
| [**NormalizedSoftmaxLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#normalizedsoftmaxloss) | - [NormFace: L2 Hypersphere Embedding for Face Verification](https://arxiv.org/pdf/1704.06369.pdf) <br/> - [Classification is a Strong Baseline for Deep Metric Learning](https://arxiv.org/pdf/1811.12649.pdf)
| [**NPairsLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#npairsloss) | [Improved Deep Metric Learning with Multi-class N-pair Loss Objective](http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf)
| [**NTXentLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#ntxentloss) | - [Representation Learning with Contrastive Predictive Coding](https://arxiv.org/pdf/1807.03748.pdf) <br/> - [Momentum Contrast for Unsupervised Visual Representation Learning](https://arxiv.org/pdf/1911.05722.pdf) <br/> - [A Simple Framework for Contrastive Learning of Visual Representations](https://arxiv.org/abs/2002.05709)
| [**PNPLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#pnploss) | [Rethinking the Optimization of Average Precision: Only Penalizing Negative Instances before Positive Ones is Enough](https://arxiv.org/pdf/2102.04640.pdf)
| [**ProxyAnchorLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#proxyanchorloss) | [Proxy Anchor Loss for Deep Metric Learning](https://arxiv.org/pdf/2003.13911.pdf)
| [**ProxyNCALoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#proxyncaloss) | [No Fuss Distance Metric Learning using Proxies](https://arxiv.org/pdf/1703.07464.pdf)
| [**SignalToNoiseRatioContrastiveLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#signaltonoiseratiocontrastiveloss) | [Signal-to-Noise Ratio: A Robust Distance Metric for Deep Metric Learning](http://openaccess.thecvf.com/content_CVPR_2019/papers/Yuan_Signal-To-Noise_Ratio_A_Robust_Distance_Metric_for_Deep_Metric_Learning_CVPR_2019_paper.pdf)
1 change: 1 addition & 0 deletions README.md
@@ -232,6 +232,7 @@ Thanks to the contributors who made pull requests!
| [elias-ramzi](https://github.com/elias-ramzi) | [HierarchicalSampler](https://kevinmusgrave.github.io/pytorch-metric-learning/samplers/#hierarchicalsampler) |
| [fjsj](https://github.com/fjsj) | [SupConLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#supconloss) |
| [AlenUbuntu](https://github.com/AlenUbuntu) | [CircleLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#circleloss) |
| [interestingzhuo](https://github.com/interestingzhuo) | [**PNPLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#pnploss) |
| [wconnell](https://github.com/wconnell) | [Learning a scRNAseq Metric Embedding](https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/examples/notebooks/scRNAseq_MetricEmbedding.ipynb) |
| [AlexSchuy](https://github.com/AlexSchuy) | optimized ```utils.loss_and_miner_utils.get_random_triplet_indices``` |
| [JohnGiorgi](https://github.com/JohnGiorgi) | ```all_gather``` in [utils.distributed](https://kevinmusgrave.github.io/pytorch-metric-learning/distributed) |
8 changes: 8 additions & 0 deletions docs/losses.md
@@ -760,6 +760,14 @@ losses.NTXentLoss(temperature=0.07, **kwargs)

* **loss**: The loss per positive pair in the batch. Reduction type is ```"pos_pair"```.


## PNPLoss
[Rethinking the Optimization of Average Precision: Only Penalizing Negative Instances before Positive Ones is Enough](https://arxiv.org/pdf/2102.04640.pdf){target=_blank}
```python
losses.PNPLoss(b=2, alpha=1, anneal=0.01, variant="O", **kwargs)
```
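A minimal usage sketch (not part of this diff; the batch size, embedding dimension, and label layout are illustrative assumptions, and PNPLoss uses `CosineSimilarity` as its default distance):

```python
import torch

from pytorch_metric_learning.losses import PNPLoss

# hypothetical toy batch: 32 embeddings of dimension 128, 4 classes with 8 samples each
embeddings = torch.randn(32, 128, requires_grad=True)
labels = torch.arange(4).repeat_interleave(8)

loss_func = PNPLoss(b=2, alpha=1, anneal=0.01, variant="O")
loss = loss_func(embeddings, labels)
loss.backward()
```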


## ProxyAnchorLoss
[Proxy Anchor Loss for Deep Metric Learning](https://arxiv.org/pdf/2003.13911.pdf){target=_blank}
```python
2 changes: 1 addition & 1 deletion src/pytorch_metric_learning/__init__.py
@@ -1 +1 @@
__version__ = "2.0.1"
__version__ = "2.1.0"
1 change: 1 addition & 0 deletions src/pytorch_metric_learning/losses/__init__.py
@@ -20,6 +20,7 @@
from .nca_loss import NCALoss
from .normalized_softmax_loss import NormalizedSoftmaxLoss
from .ntxent_loss import NTXentLoss
from .pnp_loss import PNPLoss
from .proxy_anchor_loss import ProxyAnchorLoss
from .proxy_losses import ProxyNCALoss
from .self_supervised_loss import SelfSupervisedLoss
95 changes: 95 additions & 0 deletions src/pytorch_metric_learning/losses/pnp_loss.py
@@ -0,0 +1,95 @@
import torch

from ..distances import CosineSimilarity
from ..utils import common_functions as c_f
from ..utils import loss_and_miner_utils as lmu
from .base_metric_loss_function import BaseMetricLossFunction


class PNPLoss(BaseMetricLossFunction):
VARIANTS = ["Ds", "Dq", "Iu", "Ib", "O"]

def __init__(self, b=2, alpha=1, anneal=0.01, variant="O", **kwargs):
super().__init__(**kwargs)
c_f.assert_distance_type(self, CosineSimilarity)
self.b = b
self.alpha = alpha
self.anneal = anneal
self.variant = variant
if self.variant not in self.VARIANTS:
raise ValueError(f"variant={variant} but must be one of {self.VARIANTS}")

"""
Adapted from https://github.com/interestingzhuo/PNPloss
"""

def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
c_f.indices_tuple_not_supported(indices_tuple)
c_f.labels_required(labels)
c_f.ref_not_supported(embeddings, labels, ref_emb, ref_labels)
dtype, device = embeddings.dtype, embeddings.device

N = labels.size(0)
a1_idx, p_idx, a2_idx, n_idx = lmu.get_all_pairs_indices(labels)
I_pos = torch.zeros(N, N, dtype=dtype, device=device)
I_neg = torch.zeros(N, N, dtype=dtype, device=device)
I_pos[a1_idx, p_idx] = 1
I_pos[a1_idx, a1_idx] = 1
I_neg[a2_idx, n_idx] = 1
N_pos = torch.sum(I_pos, dim=1)
safe_N = N_pos > 0
if torch.sum(safe_N) == 0:
return self.zero_losses()
sim_all = self.distance(embeddings)

mask = I_neg.unsqueeze(dim=0).repeat(N, 1, 1)

sim_all_repeat = sim_all.unsqueeze(dim=1).repeat(1, N, 1)
# compute the difference matrix
sim_diff = sim_all_repeat - sim_all_repeat.permute(0, 2, 1)
        # pass through the sigmoid and ignore the relevance score of the query to itself
sim_sg = self.sigmoid(sim_diff, temp=self.anneal) * mask
        # compute the number of negatives ranked before each positive
sim_all_rk = torch.sum(sim_sg, dim=-1)

if self.variant == "Ds":
sim_all_rk = torch.log(1 + sim_all_rk)
elif self.variant == "Dq":
sim_all_rk = 1 / (1 + sim_all_rk) ** (self.alpha)

elif self.variant == "Iu":
sim_all_rk = (1 + sim_all_rk) * torch.log(1 + sim_all_rk)

elif self.variant == "Ib":
b = self.b
sim_all_rk = 1 / b**2 * (b * sim_all_rk - torch.log(1 + b * sim_all_rk))
elif self.variant == "O":
pass
else:
raise Exception(f"variant <{self.variant}> not available!")

loss = torch.sum(sim_all_rk * I_pos, dim=-1) / N_pos.reshape(-1)
loss = torch.sum(loss) / N
if self.variant == "Dq":
loss = 1 - loss

return {
"loss": {
"losses": loss,
"indices": torch.where(safe_N)[0],
"reduction_type": "already_reduced",
}
}

def sigmoid(self, tensor, temp=1.0):
"""temperature controlled sigmoid
takes as input a torch tensor (tensor) and passes it through a sigmoid, controlled by temperature: temp
"""
exponent = -tensor / temp
# clamp the input tensor for stability
exponent = torch.clamp(exponent, min=-50, max=50)
y = 1.0 / (1.0 + torch.exp(exponent))
return y

def get_default_distance(self):
return CosineSimilarity()
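For reference, the five variants above differ only in how the soft rank `sim_all_rk` (the approximate number of negatives scored above each positive) is transformed before it is averaged over positives. The sketch below restates those transforms as a standalone, hypothetical helper for illustration; it is not part of the library code added in this commit:

```python
import math

def pnp_rank_transform(r, variant="O", b=2, alpha=1):
    # r: soft count of negatives ranked above a positive (sim_all_rk in compute_loss)
    if variant == "Ds":
        return math.log(1 + r)
    if variant == "Dq":
        # the final loss for this variant is 1 - mean(transformed ranks)
        return 1 / (1 + r) ** alpha
    if variant == "Iu":
        return (1 + r) * math.log(1 + r)
    if variant == "Ib":
        return (b * r - math.log(1 + b * r)) / b**2
    if variant == "O":
        return r
    raise ValueError(f"variant={variant} but must be one of Ds, Dq, Iu, Ib, O")

print(pnp_rank_transform(3.0, "Ds"))  # ~1.386
```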
133 changes: 133 additions & 0 deletions tests/losses/test_pnp_loss.py
@@ -0,0 +1,133 @@
import unittest

import torch
import torch.nn
import torch.nn.functional as F

from pytorch_metric_learning.losses import PNPLoss

from .. import TEST_DEVICE, TEST_DTYPES


class OriginalImplementationPNP(torch.nn.Module):
def __init__(self, b, alpha, anneal, variant, bs, classes):
super(OriginalImplementationPNP, self).__init__()
self.b = b
self.alpha = alpha
self.anneal = anneal
self.variant = variant
self.batch_size = bs
self.num_id = classes
self.samples_per_class = int(bs / classes)

mask = 1.0 - torch.eye(self.batch_size)
for i in range(self.num_id):
mask[
i * (self.samples_per_class) : (i + 1) * (self.samples_per_class),
i * (self.samples_per_class) : (i + 1) * (self.samples_per_class),
] = 0

self.mask = mask.unsqueeze(dim=0).repeat(self.batch_size, 1, 1)

def forward(self, batch):

dtype, device = batch.dtype, batch.device
self.mask = self.mask.type(dtype).to(device)
# compute the relevance scores via cosine similarity of the CNN-produced embedding vectors

sim_all = self.compute_aff(batch)

sim_all_repeat = sim_all.unsqueeze(dim=1).repeat(1, self.batch_size, 1)
# compute the difference matrix
sim_diff = sim_all_repeat - sim_all_repeat.permute(0, 2, 1)
# pass through the sigmoid and ignores the relevance score of the query to itself
sim_sg = self.sigmoid(sim_diff, temp=self.anneal) * self.mask
        # compute the rankings for the whole batch
sim_all_rk = torch.sum(sim_sg, dim=-1)
if self.variant == "PNP-D_s":
sim_all_rk = torch.log(1 + sim_all_rk)
elif self.variant == "PNP-D_q":
sim_all_rk = 1 / (1 + sim_all_rk) ** (self.alpha)

elif self.variant == "PNP-I_u":
sim_all_rk = (1 + sim_all_rk) * torch.log(1 + sim_all_rk)

elif self.variant == "PNP-I_b":
b = self.b
sim_all_rk = 1 / b**2 * (b * sim_all_rk - torch.log(1 + b * sim_all_rk))
elif self.variant == "PNP-O":
pass
else:
            raise Exception("variant <{}> not available!".format(self.variant))

# sum the values of the Smooth-AP for all instances in the mini-batch
loss = torch.zeros(1).type(dtype).to(device)
group = int(self.batch_size / self.num_id)

for ind in range(self.num_id):
neg_divide = torch.sum(
sim_all_rk[
(ind * group) : ((ind + 1) * group),
(ind * group) : ((ind + 1) * group),
]
/ group
)
loss = loss + (neg_divide / self.batch_size)
if self.variant == "PNP-D_q":
return 1 - loss
else:
return loss

def sigmoid(self, tensor, temp=1.0):
"""temperature controlled sigmoid
takes as input a torch tensor (tensor) and passes it through a sigmoid, controlled by temperature: temp
"""
exponent = -tensor / temp
# clamp the input tensor for stability
exponent = torch.clamp(exponent, min=-50, max=50)
y = 1.0 / (1.0 + torch.exp(exponent))
return y

def compute_aff(self, x):
"""computes the affinity matrix between an input vector and itself"""
return torch.mm(x, x.t())


class TestPNPLoss(unittest.TestCase):
def test_pnp_loss(self):
torch.manual_seed(30293)
bs = 180
classes = 30
for variant in PNPLoss.VARIANTS:
original_variant = {
"Ds": "PNP-D_s",
"Dq": "PNP-D_q",
"Iu": "PNP-I_u",
"Ib": "PNP-I_b",
"O": "PNP-O",
}[variant]
b, alpha, anneal = 2, 4, 0.01
loss_func = PNPLoss(b, alpha, anneal, variant)
original_loss_func = OriginalImplementationPNP(
b, alpha, anneal, original_variant, bs, classes
).to(TEST_DEVICE)

for dtype in TEST_DTYPES:
embeddings = torch.randn(
180, 32, dtype=dtype, device=TEST_DEVICE, requires_grad=True
)
labels = (
torch.tensor([[i] * (int(bs / classes)) for i in range(classes)])
.reshape(-1)
.to(TEST_DEVICE)
)
loss = loss_func(embeddings, labels)
loss.backward()
correct_loss = original_loss_func(F.normalize(embeddings, dim=-1))

rtol = 1e-2 if dtype == torch.float16 else 1e-5
self.assertTrue(torch.isclose(loss, correct_loss[0], rtol=rtol))

with self.assertRaises(ValueError):
PNPLoss(b, alpha, anneal, "PNP")
