From ef1bd0627633ab2fb3c8a2a1e58938de866f5b2c Mon Sep 17 00:00:00 2001 From: KevinMusgrave Date: Wed, 24 Jul 2024 11:43:27 +0000 Subject: [PATCH] formatted code --- src/pytorch_metric_learning/utils/distributed.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/pytorch_metric_learning/utils/distributed.py b/src/pytorch_metric_learning/utils/distributed.py index ab205f41..40dddf90 100644 --- a/src/pytorch_metric_learning/utils/distributed.py +++ b/src/pytorch_metric_learning/utils/distributed.py @@ -109,7 +109,14 @@ def forward( return self.loss(embeddings, labels, indices_tuple, ref_emb, ref_labels) world_size = torch.distributed.get_world_size() - common_args = [embeddings, labels, indices_tuple, ref_emb, ref_labels, world_size] + common_args = [ + embeddings, + labels, + indices_tuple, + ref_emb, + ref_labels, + world_size, + ] if isinstance(self.loss, CrossBatchMemory): return self.forward_cross_batch(*common_args, enqueue_mask) return self.forward_regular_loss(*common_args)