Skip to content

Commit

Permalink
formatted code
Browse files Browse the repository at this point in the history
  • Loading branch information
KevinMusgrave committed Jul 24, 2024
1 parent 0db2483 commit ef1bd06
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion src/pytorch_metric_learning/utils/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,14 @@ def forward(
return self.loss(embeddings, labels, indices_tuple, ref_emb, ref_labels)

world_size = torch.distributed.get_world_size()
common_args = [embeddings, labels, indices_tuple, ref_emb, ref_labels, world_size]
common_args = [
embeddings,
labels,
indices_tuple,
ref_emb,
ref_labels,
world_size,
]
if isinstance(self.loss, CrossBatchMemory):
return self.forward_cross_batch(*common_args, enqueue_mask)
return self.forward_regular_loss(*common_args)
Expand Down

0 comments on commit ef1bd06

Please sign in to comment.