Merge pull request #294 from KevinMusgrave/dev

v0.9.98

Kevin Musgrave authored Apr 3, 2021
2 parents 251bde4 + 883b4cb commit 987f2e9
Showing 22 changed files with 548 additions and 377 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -259,6 +259,7 @@ conda install pytorch-metric-learning -c metric-learning -c pytorch
| [**SignalToNoiseRatioContrastiveLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#signaltonoiseratiocontrastiveloss) | [Signal-to-Noise Ratio: A Robust Distance Metric for Deep Metric Learning](http://openaccess.thecvf.com/content_CVPR_2019/papers/Yuan_Signal-To-Noise_Ratio_A_Robust_Distance_Metric_for_Deep_Metric_Learning_CVPR_2019_paper.pdf)
| [**SoftTripleLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#softtripleloss) | [SoftTriple Loss: Deep Metric Learning Without Triplet Sampling](http://openaccess.thecvf.com/content_ICCV_2019/papers/Qian_SoftTriple_Loss_Deep_Metric_Learning_Without_Triplet_Sampling_ICCV_2019_paper.pdf)
| [**SphereFaceLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#spherefaceloss) | [SphereFace: Deep Hypersphere Embedding for Face Recognition](https://arxiv.org/pdf/1704.08063.pdf)
| [**SupConLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#supconloss) | [Supervised Contrastive Learning](https://arxiv.org/abs/2004.11362)
| [**TripletMarginLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#tripletmarginloss) | [Distance Metric Learning for Large Margin Nearest Neighbor Classification](https://papers.nips.cc/paper/2795-distance-metric-learning-for-large-margin-nearest-neighbor-classification.pdf)
| [**TupletMarginLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#tupletmarginloss) | [Deep Metric Learning with Tuplet Margin Loss](http://openaccess.thecvf.com/content_ICCV_2019/papers/Yu_Deep_Metric_Learning_With_Tuplet_Margin_Loss_ICCV_2019_paper.pdf)

@@ -401,6 +402,8 @@ Thanks to the contributors who made pull requests!
- ```all_gather``` in [utils.distributed](https://kevinmusgrave.github.io/pytorch-metric-learning/distributed)
- [AlexSchuy](https://github.com/AlexSchuy)
- optimized ```utils.loss_and_miner_utils.get_random_triplet_indices```
- [fjsj](https://github.com/fjsj)
- [SupConLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#supconloss)

#### Example notebooks
- [wconnell](https://github.com/wconnell)
34 changes: 24 additions & 10 deletions docs/accuracy_calculation.md
@@ -15,10 +15,10 @@ AccuracyCalculator(include=(),
* **include**: Optional. A list or tuple of strings, which are the names of metrics you want to calculate. If left empty, all default metrics will be calculated.
* **exclude**: Optional. A list or tuple of strings, which are the names of metrics you **do not** want to calculate.
* **avg_of_avgs**: If True, the average accuracy per class is computed, and then the average of those averages is returned. This can be useful if your dataset has unbalanced classes. If False, the global average will be returned.
* **k**: If set, this number of nearest neighbors will be retrieved for metrics that require k-nearest neighbors. If None, the value of k will be determined as follows:
* First, count the number of occurrences of each label in ```reference_labels```, and set k to the maximum value. For example, if ```reference_labels``` is ```[0, 0, 1, 1, 1, 1, 2]```, then ```k = 4```.
* Then set ```k = min(1023, k)```. This is done because faiss (the library that computes k-nn) has a limit on the value of k.
* After k is set, the ```k+1``` nearest neighbors are found for every query sample. When the query and reference come from the same source, the 1st nearest neighbor is discarded since that "neighbor" is actually the query sample. When the query and reference come from different sources, the ```k+1``` neighbor is discarded.
* **k**: The number of nearest neighbors that will be retrieved for metrics that require k-nearest neighbors. The allowed values are:
* ```None```. This means k will be set to the total number of reference embeddings.
* ```x```, where ```x > 0```. This means k will be set to x.
* ```"max_bin_count"```. This means k will be set to ```max(bincount(reference_labels)) - self_count``` where ```self_count == 1``` if the query and reference embeddings come from the same source.
* **label_comparison_fn**: A function that compares two torch arrays of labels and returns a boolean array. The default is ```torch.eq```. If a custom function is used, then you must exclude clustering based metrics ("NMI" and "AMI"). The following is an example of a custom function for two-dimensional labels. It returns ```True``` if the 0th column matches, and the 1st column does **not** match:
```python
def example_label_comparison_fn(x, y):
    # True where the 0th column matches and the 1st column does not
    return (x[:, 0] == y[:, 0]) & (x[:, 1] != y[:, 1])
```
@@ -60,6 +60,14 @@ def get_accuracy(self,
Note that labels can be 2D if a [custom label comparison function](#using-a-custom-label-comparison-function) is used.
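
As a self-contained sketch (the embeddings and 2D labels below are made up for illustration, and the comparison function is repeated from the [Parameters](#parameters) section so the snippet runs on its own):

```python
import torch
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

def example_label_comparison_fn(x, y):
    # True where the 0th column matches and the 1st column does not
    return (x[:, 0] == y[:, 0]) & (x[:, 1] != y[:, 1])

query = torch.randn(4, 32)       # 4 query embeddings, 32-dim
reference = torch.randn(6, 32)   # 6 reference embeddings
query_labels = torch.tensor([[0, 2], [1, 3], [0, 5], [1, 1]])
reference_labels = torch.tensor([[0, 4], [1, 2], [1, 3], [0, 1], [0, 2], [1, 5]])

# Clustering-based metrics ("NMI", "AMI") must be left out
# when a custom label comparison function is used.
calc = AccuracyCalculator(include=("precision_at_1",),
                          label_comparison_fn=example_label_comparison_fn)
accuracies = calc.get_accuracy(query, reference, query_labels, reference_labels, False)
```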


### CPU/GPU usage

* If you installed ```faiss-cpu``` then the CPU will always be used.
* If you installed ```faiss-gpu```, then the GPU will be used if ```k <= 1024``` for CUDA < 9.5, and ```k <= 2048``` for CUDA >= 9.5. If this condition is not met, then the CPU will be used.

If your dataset is large, you might find that the k-nn search is very slow. This is because the default behavior (```k = None```) is to search the entire dataset. To avoid this, set k to a fixed number, like ```k = 1000```, or try ```k = "max_bin_count"```.
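
For example, a minimal sketch with random data (any of the three forms of ```k``` shown below is valid):

```python
import torch
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

embeddings = torch.randn(1000, 64)      # stand-in for a model's output
labels = torch.randint(0, 10, (1000,))  # 10 classes

calc = AccuracyCalculator(k="max_bin_count")  # k from the largest class size
# calc = AccuracyCalculator(k=1000)           # fixed cap on the k-nn search
# calc = AccuracyCalculator(k=None)           # search the whole reference set

# Query and reference are the same source here, so the last argument is True
# and each sample is excluded from its own neighbors.
accuracies = calc.get_accuracy(embeddings, embeddings, labels, labels, True)
```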


### Explanations of the default accuracy metrics

- **AMI**:
@@ -91,11 +99,7 @@ Note that labels can be 2D if a [custom label comparison function](#using-a-cust

**Important note**

AccuracyCalculator's ```mean_average_precision_at_r``` and ```r_precision``` are correct only if the following are true:

* every query class has less than 1024 samples in the reference set
* ```k = None```, **or** if you set ```k```, it must be ```>=``` the value that is automatically calculated (see the [parameters](#parameters) section)

AccuracyCalculator's ```mean_average_precision_at_r``` and ```r_precision``` are correct only if ```k = None```, **or** ```k = "max_bin_count"```, **or** ```k >= max(bincount(reference_labels))```.


### Adding custom accuracy metrics
@@ -194,4 +198,14 @@ labels = torch.tensor([
0.04,
0.05,
])
```


### Warning for versions <= 0.9.97

The behavior of the ```k``` parameter described in the [Parameters](#parameters) section is for versions >= 0.9.98.

For versions <= 0.9.97, the behavior was:

* If ```k = None```, then ```k = min(1023, max(bincount(reference_labels)))```
* Otherwise ```k = k```
Binary file added docs/imgs/supcon_loss_equation.png
27 changes: 27 additions & 0 deletions docs/losses.md
@@ -893,6 +893,33 @@ loss_optimizer.step()
* **loss**: The loss per element in the batch. Reduction type is ```"element"```.


## SupConLoss
Described in [Supervised Contrastive Learning](https://arxiv.org/abs/2004.11362){target=_blank}.
```python
losses.SupConLoss(temperature=0.1, **kwargs)
```

**Equation**:

![supcon_loss_equation](imgs/supcon_loss_equation.png){: style="height:90px"}
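
In case the image doesn't render, this is a transcription of the paper's L_out^sup form, where ```P(i)``` is the set of positives for anchor ```i```, ```A(i)``` is the set of all other samples in the batch, and ```tau``` is the temperature:

```latex
\mathcal{L}^{sup}_{out} = \sum_{i \in I} \frac{-1}{|P(i)|} \sum_{p \in P(i)}
    \log \frac{\exp(z_i \cdot z_p / \tau)}{\sum_{a \in A(i)} \exp(z_i \cdot z_a / \tau)}
```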

**Parameters**:

* **temperature**: This is tau in the above equation. The paper uses 0.1.

**Default distance**:

- [```CosineSimilarity()```](distances.md#cosinesimilarity)

**Default reducer**:

- [AvgNonZeroReducer](reducers.md#avgnonzeroreducer)

**Reducer input**:

* **loss**: The loss per element in the batch. If an element has only negative pairs or no pairs, it's ignored thanks to `AvgNonZeroReducer`. Reduction type is ```"element"```.
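
A minimal usage sketch (random embeddings stand in for a model's output):

```python
import torch
from pytorch_metric_learning import losses

loss_func = losses.SupConLoss(temperature=0.1)
embeddings = torch.randn(32, 128, requires_grad=True)  # batch of 32, 128-dim
labels = torch.randint(0, 5, (32,))                    # 5 classes
loss = loss_func(embeddings, labels)
loss.backward()
```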


## TripletMarginLoss

```python
66 changes: 21 additions & 45 deletions examples/notebooks/CascadedEmbeddings.ipynb
@@ -1015,8 +1015,7 @@
{
"cell_type": "markdown",
"metadata": {
"id": "-LzofXOCE71t",
"colab_type": "text"
"id": "-LzofXOCE71t"
},
"source": [
"# PyTorch Metric Learning\n",
@@ -1027,8 +1026,7 @@
{
"cell_type": "markdown",
"metadata": {
"id": "MKpRHvy24tV7",
"colab_type": "text"
"id": "MKpRHvy24tV7"
},
"source": [
"## Install the necessary packages"
@@ -1038,15 +1036,15 @@
"cell_type": "code",
"metadata": {
"id": "ZeIGxbbp3W2S",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 87
},
"outputId": "8ce9b687-a7e5-4880-be70-3df801f7893e"
},
"source": [
"!pip install -q pytorch-metric-learning[with-hooks]"
"!pip install -q pytorch-metric-learning[with-hooks]\n",
"!pip install umap-learn"
],
"execution_count": null,
"outputs": [
@@ -1065,8 +1063,7 @@
{
"cell_type": "markdown",
"metadata": {
"id": "BfqRRbIw4zYR",
"colab_type": "text"
"id": "BfqRRbIw4zYR"
},
"source": [
"## Import the packages"
@@ -1076,7 +1073,6 @@
"cell_type": "code",
"metadata": {
"id": "567qnmi7wk_M",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
@@ -1088,6 +1084,7 @@
"from pytorch_metric_learning import losses, miners, samplers, trainers, testers\n",
"from pytorch_metric_learning.utils import common_functions\n",
"import pytorch_metric_learning.utils.logging_presets as logging_presets\n",
"from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator\n",
"import numpy as np\n",
"import torchvision\n",
"from torchvision import datasets, transforms\n",
@@ -1118,8 +1115,7 @@
{
"cell_type": "markdown",
"metadata": {
"id": "Qxs6EEeR496q",
"colab_type": "text"
"id": "Qxs6EEeR496q"
},
"source": [
"## Model defs"
@@ -1128,9 +1124,7 @@
{
"cell_type": "code",
"metadata": {
"id": "zKyR6gnTwk_P",
"colab_type": "code",
"colab": {}
"id": "zKyR6gnTwk_P"
},
"source": [
"class MLP(nn.Module):\n",
@@ -1184,8 +1178,7 @@
{
"cell_type": "markdown",
"metadata": {
"id": "btjxk6zR5Cl6",
"colab_type": "text"
"id": "btjxk6zR5Cl6"
},
"source": [
"## Initialize models, optimizers and image transforms"
@@ -1195,7 +1188,6 @@
"cell_type": "code",
"metadata": {
"id": "8tzmyFS3wk_R",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 217,
@@ -1374,8 +1366,7 @@
{
"cell_type": "markdown",
"metadata": {
"id": "Xf0xgdWS5GqG",
"colab_type": "text"
"id": "Xf0xgdWS5GqG"
},
"source": [
"## Create the dataset and class-disjoint train/val splits"
@@ -1385,7 +1376,6 @@
"cell_type": "code",
"metadata": {
"id": "D-nmnYYAwk_T",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 101,
@@ -1470,8 +1460,7 @@
{
"cell_type": "markdown",
"metadata": {
"id": "r7J817Vs5LNs",
"colab_type": "text"
"id": "r7J817Vs5LNs"
},
"source": [
"##Create the loss, miner, sampler, and package them into dictionaries\n"
@@ -1481,9 +1470,7 @@
"cell_type": "code",
"metadata": {
"scrolled": false,
"id": "Kp9AC_4Dwk_V",
"colab_type": "code",
"colab": {}
"id": "Kp9AC_4Dwk_V"
},
"source": [
"# Set the loss functions. loss0 will be applied to the first embedder, loss1 to the second embedder etc.\n",
@@ -1514,9 +1501,7 @@
{
"cell_type": "code",
"metadata": {
"id": "23ddQ3anM-sO",
"colab_type": "code",
"colab": {}
"id": "23ddQ3anM-sO"
},
"source": [
"# Remove logs if you want to train with new parameters\n",
@@ -1528,8 +1513,7 @@
{
"cell_type": "markdown",
"metadata": {
"id": "A9fa1VYD5Yv0",
"colab_type": "text"
"id": "A9fa1VYD5Yv0"
},
"source": [
"## Create the training and testing hooks"
@@ -1538,9 +1522,7 @@
{
"cell_type": "code",
"metadata": {
"id": "Vq_Pd7Pd5Xi_",
"colab_type": "code",
"colab": {}
"id": "Vq_Pd7Pd5Xi_"
},
"source": [
"record_keeper, _, _ = logging_presets.get_record_keeper(\"example_logs\", \"example_tensorboard\")\n",
@@ -1563,7 +1545,8 @@
"tester = testers.GlobalEmbeddingSpaceTester(end_of_testing_hook = hooks.end_of_testing_hook, \n",
" visualizer = umap.UMAP(), \n",
" visualizer_hook = visualizer_hook,\n",
" dataloader_num_workers = 32)\n",
" dataloader_num_workers = 32,\n",
" accuracy_calculator=AccuracyCalculator(k=\"max_bin_count\"))\n",
"\n",
"end_of_epoch_hook = hooks.end_of_epoch_hook(tester, \n",
" dataset_dict, \n",
@@ -1577,8 +1560,7 @@
{
"cell_type": "markdown",
"metadata": {
"id": "A0D3Jvxc5iWD",
"colab_type": "text"
"id": "A0D3Jvxc5iWD"
},
"source": [
"## Create the trainer"
@@ -1589,7 +1571,6 @@
"metadata": {
"scrolled": false,
"id": "DuASrVs-wk_X",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
@@ -1627,8 +1608,7 @@
{
"cell_type": "markdown",
"metadata": {
"id": "jbt_j2TzQ669",
"colab_type": "text"
"id": "jbt_j2TzQ669"
},
"source": [
"## Start Tensorboard\n",
@@ -1638,9 +1618,7 @@
{
"cell_type": "code",
"metadata": {
"id": "s6HCNGf2Q6_X",
"colab_type": "code",
"colab": {}
"id": "s6HCNGf2Q6_X"
},
"source": [
"%load_ext tensorboard\n",
@@ -1652,8 +1630,7 @@
{
"cell_type": "markdown",
"metadata": {
"id": "gIq7s7jf5ksj",
"colab_type": "text"
"id": "gIq7s7jf5ksj"
},
"source": [
"## Train the model"
@@ -1663,7 +1640,6 @@
"cell_type": "code",
"metadata": {
"id": "WHza2JJHwk_Z",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000