Skip to content

Commit

Permalink
Merge pull request #692 from KevinMusgrave/dev
Browse files Browse the repository at this point in the history
v2.5.0
  • Loading branch information
Kevin Musgrave authored Apr 1, 2024
2 parents 3a14f82 + ef65345 commit 5e5319d
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 17 deletions.
2 changes: 1 addition & 1 deletion src/pytorch_metric_learning/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "2.4.1"
__version__ = "2.5.0"
18 changes: 9 additions & 9 deletions src/pytorch_metric_learning/losses/manifold_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,13 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
if self.lambdaC != np.inf:
F = F[:N, N:]
loss_int = F - F[torch.arange(N), meta_classes].view(-1, 1) + self.margin
loss_int[
torch.arange(N), meta_classes
] = -np.inf # This way avoid numerical cancellation happening # NoQA
loss_int[torch.arange(N), meta_classes] = (
-np.inf
) # This way avoid numerical cancellation happening # NoQA
# instead with subtraction of margin term # NoQA
loss_int[
loss_int < 0
] = -np.inf # This way no loss for positive correlation with own proxy
loss_int[loss_int < 0] = (
-np.inf
) # This way no loss for positive correlation with own proxy

loss_int = torch.exp(loss_int)
loss_int = torch.log(1 + torch.sum(loss_int, dim=1))
Expand All @@ -106,9 +106,9 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
F_e, F_p.unsqueeze(1), dim=-1
).t()
loss_ctx += -loss_ctx[torch.arange(N), meta_classes].view(-1, 1) + self.margin
loss_ctx[
torch.arange(N), meta_classes
] = -np.inf # This way avoid numerical cancellation happening # NoQA
loss_ctx[torch.arange(N), meta_classes] = (
-np.inf
) # This way avoid numerical cancellation happening # NoQA
# instead with subtraction of margin term # NoQA
loss_ctx[loss_ctx < 0] = -np.inf

Expand Down
6 changes: 4 additions & 2 deletions src/pytorch_metric_learning/testers/base_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,10 @@ def test(
query_split_name,
reference_split_names,
)
self.end_of_testing_hook(self) if self.end_of_testing_hook else c_f.LOGGER.info(
self.all_accuracies
(
self.end_of_testing_hook(self)
if self.end_of_testing_hook
else c_f.LOGGER.info(self.all_accuracies)
)
del self.embeddings_and_labels
return self.all_accuracies
50 changes: 48 additions & 2 deletions src/pytorch_metric_learning/utils/loss_and_miner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,57 @@ def neg_pairs_from_tuple(indices_tuple):


def get_all_triplets_indices(labels, ref_labels=None):
matches, diffs = get_matches_and_diffs(labels, ref_labels)
triplets = matches.unsqueeze(2) * diffs.unsqueeze(1)
all_matches, all_diffs = get_matches_and_diffs(labels, ref_labels)

if (
all_matches.shape[0] * all_matches.shape[1] * all_matches.shape[1]
< torch.iinfo(torch.int32).max
):
# torch.nonzero is not supported for tensors with more than INT_MAX elements
return get_all_triplets_indices_vectorized_method(all_matches, all_diffs)

return get_all_triplets_indices_loop_method(labels, all_matches, all_diffs)


def get_all_triplets_indices_vectorized_method(all_matches, all_diffs):
triplets = all_matches.unsqueeze(2) * all_diffs.unsqueeze(1)
return torch.where(triplets)


def get_all_triplets_indices_loop_method(labels, all_matches, all_diffs):
all_matches, all_diffs = all_matches.bool(), all_diffs.bool()

# Find anchors with at least a positive and a negative
indices = torch.arange(0, len(labels), device=labels.device)
indices = indices[all_matches.any(dim=1) & all_diffs.any(dim=1)]

# No triplets found
if len(indices) == 0:
return (
torch.tensor([], device=labels.device, dtype=labels.dtype),
torch.tensor([], device=labels.device, dtype=labels.dtype),
torch.tensor([], device=labels.device, dtype=labels.dtype),
)

# Compute all triplets
anchors = []
positives = []
negatives = []
for i in indices:
matches = all_matches[i].nonzero(as_tuple=False).squeeze(1)
diffs = all_diffs[i].nonzero(as_tuple=False).squeeze(1)
nd = len(diffs)
nm = len(matches)
matches = matches.repeat_interleave(nd)
diffs = diffs.repeat(nm)
anchors.append(
torch.full((len(matches),), i, dtype=labels.dtype, device=labels.device)
)
positives.append(matches)
negatives.append(diffs)
return torch.cat(anchors), torch.cat(positives), torch.cat(negatives)


# sample triplets, with a weighted distribution if weights is specified.
def get_random_triplet_indices(
labels, ref_labels=None, t_per_anchor=None, weights=None
Expand Down
8 changes: 5 additions & 3 deletions tests/utils/test_calculate_accuracies.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,11 @@ def test_accuracy_calculator(self):
"query_labels": query_labels,
"label_counts": label_counts,
"knn_labels": knn_labels,
"not_lone_query_mask": torch.ones(6, dtype=torch.bool)
if i == 0
else torch.zeros(6, dtype=torch.bool),
"not_lone_query_mask": (
torch.ones(6, dtype=torch.bool)
if i == 0
else torch.zeros(6, dtype=torch.bool)
),
}

function_dict = AC.get_function_dict()
Expand Down
23 changes: 23 additions & 0 deletions tests/utils/test_loss_and_miner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,29 @@ def test_remove_self_comparisons_small_ref(self):
self.assertTrue(torch.equal(a1, correct_a1))
self.assertTrue(torch.equal(p, correct_p))

def test_get_all_triplets_indices(self):
torch.manual_seed(920)
for dtype in TEST_DTYPES:
for batch_size in [32, 256, 512]:
for ref_labels in [None, torch.randint(0, 5, size=(batch_size // 2,))]:
labels = torch.randint(0, 5, size=(batch_size,))

a, p, n = lmu.get_all_triplets_indices(labels, ref_labels)
matches, diffs = lmu.get_matches_and_diffs(labels, ref_labels)

a2, p2, n2 = lmu.get_all_triplets_indices_vectorized_method(
matches, diffs
)
a3, p3, n3 = lmu.get_all_triplets_indices_loop_method(
labels, matches, diffs
)
self.assertTrue(
(a == a2).all() and (p == p2).all() and (n == n2).all()
)
self.assertTrue(
(a == a3).all() and (p == p3).all() and (n == n3).all()
)


if __name__ == "__main__":
unittest.main()

0 comments on commit 5e5319d

Please sign in to comment.