Merge pull request #702 from elisim/master
DistributedLossWrapper enhancements
Kevin Musgrave authored Jul 23, 2024
2 parents ef65345 + ed2c1e2 commit 0db2483
Showing 5 changed files with 21 additions and 11 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/base_test_workflow.yml
@@ -18,8 +18,8 @@ jobs:
pytorch-version: 1.6
torchvision-version: 0.7
- python-version: 3.9
pytorch-version: 2.1
torchvision-version: 0.16
pytorch-version: 2.3
torchvision-version: 0.18

steps:
- uses: actions/checkout@v2
@@ -30,7 +30,7 @@ jobs:
- name: Install dependencies
run: |
pip install .[with-hooks-cpu]
pip install torch==${{ matrix.pytorch-version }} torchvision==${{ matrix.torchvision-version }} --force-reinstall
pip install "numpy<2.0" torch==${{ matrix.pytorch-version }} torchvision==${{ matrix.torchvision-version }} --force-reinstall
pip install --upgrade protobuf==3.20.1
pip install six
pip install packaging
10 changes: 6 additions & 4 deletions README.md
@@ -18,16 +18,17 @@

## News

**April 1**: v2.5.0
- Improved `get_all_triplets_indices` so that large batch sizes don't trigger the `INT_MAX` error.
- See the [release notes](https://github.com/KevinMusgrave/pytorch-metric-learning/releases/tag/v2.5.0).
- Thank you [mkmenta](https://github.com/mkmenta).
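The `INT_MAX` error mentioned above arises because the number of (anchor, positive, negative) triples grows cubically with batch size and can overflow 32-bit indexing. A minimal pure-Python sketch, not the library's vectorized implementation, of what `get_all_triplets_indices` enumerates:

```python
# Pure-Python sketch of the triplet enumeration behind
# get_all_triplets_indices: every (anchor, positive, negative) index
# triple where anchor and positive share a label and negative differs.
# The real library computes this with vectorized torch ops.
def all_triplet_indices(labels):
    n = len(labels)
    triplets = []
    for a in range(n):
        for p in range(n):
            # positive must be a different index with the same label
            if p == a or labels[p] != labels[a]:
                continue
            for neg in range(n):
                # negative must carry a different label
                if labels[neg] != labels[a]:
                    triplets.append((a, p, neg))
    return triplets

print(all_triplet_indices([0, 0, 1]))  # [(0, 1, 2), (1, 0, 2)]
```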

**December 15**: v2.4.0
- Added [DynamicSoftMarginLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#dynamicsoftmarginloss).
- Added [RankedListLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#rankedlistloss).
- See the [release notes](https://github.com/KevinMusgrave/pytorch-metric-learning/releases/tag/v2.4.0).
- Thank you [domenicoMuscill0](https://github.com/domenicoMuscill0), [Puzer](https://github.com/Puzer), [interestingzhuo](https://github.com/interestingzhuo), and [GaetanLepage](https://github.com/GaetanLepage).

**July 25**: v2.3.0
- Added [HistogramLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#histogramloss)
- Thank you [domenicoMuscill0](https://github.com/domenicoMuscill0).

## Documentation
- [**View the documentation here**](https://kevinmusgrave.github.io/pytorch-metric-learning/)
- [**View the installation instructions here**](https://github.com/KevinMusgrave/pytorch-metric-learning#installation)
@@ -236,6 +237,7 @@ Thanks to the contributors who made pull requests!
| [AlenUbuntu](https://github.com/AlenUbuntu) | [CircleLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#circleloss) |
| [interestingzhuo](https://github.com/interestingzhuo) | [PNPLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#pnploss) |
| [wconnell](https://github.com/wconnell) | [Learning a scRNAseq Metric Embedding](https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/examples/notebooks/scRNAseq_MetricEmbedding.ipynb) |
| [mkmenta](https://github.com/mkmenta) | Improved `get_all_triplets_indices` (fixed the `INT_MAX` error) |
| [AlexSchuy](https://github.com/AlexSchuy) | optimized ```utils.loss_and_miner_utils.get_random_triplet_indices``` |
| [JohnGiorgi](https://github.com/JohnGiorgi) | ```all_gather``` in [utils.distributed](https://kevinmusgrave.github.io/pytorch-metric-learning/distributed) |
| [Hummer12007](https://github.com/Hummer12007) | ```utils.key_checker``` |
2 changes: 1 addition & 1 deletion setup.py
@@ -39,7 +39,7 @@
],
python_requires=">=3.0",
install_requires=[
"numpy",
"numpy < 2.0",
"scikit-learn",
"tqdm",
"torch >= 1.6.0",
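The `numpy < 2.0` pin reflects that extensions built against the NumPy 1.x ABI can fail to import under NumPy 2.0. A stdlib-only sketch (a hypothetical helper, not part of the package) of the constraint being enforced:

```python
# Hypothetical helper illustrating the "numpy < 2.0" constraint added to
# install_requires; not part of pytorch-metric-learning itself.
def numpy_pin_satisfied(installed_version: str) -> bool:
    # The pin only cares about the major version: anything below 2.0 passes.
    major = int(installed_version.split(".")[0])
    return major < 2

print(numpy_pin_satisfied("1.26.4"))  # True
print(numpy_pin_satisfied("2.0.0"))   # False
```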
2 changes: 1 addition & 1 deletion src/pytorch_metric_learning/__init__.py
@@ -1 +1 @@
__version__ = "2.5.0"
__version__ = "2.6.0"
12 changes: 10 additions & 2 deletions src/pytorch_metric_learning/utils/distributed.py
@@ -1,3 +1,5 @@
import warnings

import torch

from ..losses import BaseMetricLossFunction, CrossBatchMemory
@@ -93,15 +95,21 @@ def __init__(self, loss, efficient=False):

def forward(
self,
emb,
embeddings,
labels=None,
indices_tuple=None,
ref_emb=None,
ref_labels=None,
enqueue_mask=None,
):
if not is_distributed():
warnings.warn(
"DistributedLossWrapper is being used in a non-distributed setting. Returning the loss as is."
)
return self.loss(embeddings, labels, indices_tuple, ref_emb, ref_labels)

world_size = torch.distributed.get_world_size()
common_args = [emb, labels, indices_tuple, ref_emb, ref_labels, world_size]
common_args = [embeddings, labels, indices_tuple, ref_emb, ref_labels, world_size]
if isinstance(self.loss, CrossBatchMemory):
return self.forward_cross_batch(*common_args, enqueue_mask)
return self.forward_regular_loss(*common_args)
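The key change in the hunk above is the non-distributed fallback: `forward` now warns and delegates directly to the wrapped loss when `torch.distributed` is not initialized, instead of assuming a multi-process setup. A self-contained sketch of that dispatch logic, with stand-in names and no torch dependency:

```python
import warnings

# Stand-in for the real check, which consults torch.distributed
# (is_available() and is_initialized()); hardcoded here for the sketch.
def is_distributed():
    return False

class DistributedLossWrapperSketch:
    """Sketch of the fallback added in this commit, not the real class."""

    def __init__(self, loss):
        self.loss = loss

    def forward(self, embeddings, labels=None, indices_tuple=None,
                ref_emb=None, ref_labels=None):
        if not is_distributed():
            # New behavior: warn and return the wrapped loss unchanged
            # rather than attempting a cross-rank gather.
            warnings.warn(
                "DistributedLossWrapper is being used in a non-distributed "
                "setting. Returning the loss as is."
            )
            return self.loss(embeddings, labels, indices_tuple,
                             ref_emb, ref_labels)
        raise NotImplementedError("all-gather path omitted in this sketch")

# Toy loss standing in for a real metric-learning loss.
wrapper = DistributedLossWrapperSketch(
    loss=lambda emb, labels, idx, ref_e, ref_l: sum(emb)
)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    value = wrapper.forward([1.0, 2.0, 3.0], labels=[0, 1, 0])
print(value)        # 6.0
print(len(caught))  # 1 warning emitted
```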
