diff --git a/README.md b/README.md index 4aa8b344..0278de26 100644 --- a/README.md +++ b/README.md @@ -259,6 +259,7 @@ conda install pytorch-metric-learning -c metric-learning -c pytorch | [**SignalToNoiseRatioContrastiveLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#signaltonoiseratiocontrastiveloss) | [Signal-to-Noise Ratio: A Robust Distance Metric for Deep Metric Learning](http://openaccess.thecvf.com/content_CVPR_2019/papers/Yuan_Signal-To-Noise_Ratio_A_Robust_Distance_Metric_for_Deep_Metric_Learning_CVPR_2019_paper.pdf) | [**SoftTripleLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#softtripleloss) | [SoftTriple Loss: Deep Metric Learning Without Triplet Sampling](http://openaccess.thecvf.com/content_ICCV_2019/papers/Qian_SoftTriple_Loss_Deep_Metric_Learning_Without_Triplet_Sampling_ICCV_2019_paper.pdf) | [**SphereFaceLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#spherefaceloss) | [SphereFace: Deep Hypersphere Embedding for Face Recognition](https://arxiv.org/pdf/1704.08063.pdf) +| [**SupConLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#supconloss) | [Supervised Contrastive Learning](https://arxiv.org/abs/2004.11362) | [**TripletMarginLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#tripletmarginloss) | [Distance Metric Learning for Large Margin Nearest Neighbor Classification](https://papers.nips.cc/paper/2795-distance-metric-learning-for-large-margin-nearest-neighbor-classification.pdf) | [**TupletMarginLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#tupletmarginloss) | [Deep Metric Learning with Tuplet Margin Loss](http://openaccess.thecvf.com/content_ICCV_2019/papers/Yu_Deep_Metric_Learning_With_Tuplet_Margin_Loss_ICCV_2019_paper.pdf) @@ -401,6 +402,8 @@ Thanks to the contributors who made pull requests! - ```all_gather``` in [utils.distributed](https://kevinmusgrave.github.io/pytorch-metric-learning/distributed) - [AlexSchuy](https://github.com/AlexSchuy) - optimized ```utils.loss_and_miner_utils.get_random_triplet_indices``` +- [fjsj](https://github.com/fjsj) + - [SupConLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#supconloss) #### Example notebooks - [wconnell](https://github.com/wconnell) diff --git a/docs/accuracy_calculation.md b/docs/accuracy_calculation.md index 4af86ad5..44b83faf 100644 --- a/docs/accuracy_calculation.md +++ b/docs/accuracy_calculation.md @@ -15,10 +15,10 @@ AccuracyCalculator(include=(), * **include**: Optional. A list or tuple of strings, which are the names of metrics you want to calculate. If left empty, all default metrics will be calculated. * **exclude**: Optional. A list or tuple of strings, which are the names of metrics you **do not** want to calculate. * **avg_of_avgs**: If True, the average accuracy per class is computed, and then the average of those averages is returned. This can be useful if your dataset has unbalanced classes. If False, the global average will be returned. -* **k**: If set, this number of nearest neighbors will be retrieved for metrics that require k-nearest neighbors. If None, the value of k will be determined as follows: - * First, count the number of occurrences of each label in ```reference_labels```, and set k to the maximum value. For example, if ```reference_labels``` is ```[0, 0, 1, 1, 1, 1, 2]```, then ```k = 4```. - * Then set ```k = min(1023, k)```. This is done because faiss (the library that computes k-nn) has a limit on the value of k. 
- * After k is set, the ```k+1``` nearest neighbors are found for every query sample. When the query and reference come from the same source, the 1st nearest neighbor is discarded since that "neighbor" is actually the query sample. When the query and reference come from different sources, the ```k+1``` neighbor is discarded.
+* **k**: The number of nearest neighbors that will be retrieved for metrics that require k-nearest neighbors. The allowed values are:
+    * ```None```. This means k will be set to the total number of reference embeddings.
+    * ```x```, where ```x > 0```. This means k will be set to x.
+    * ```"max_bin_count"```. This means k will be set to ```max(bincount(reference_labels)) - self_count```, where ```self_count == 1``` if the query and reference embeddings come from the same source, and ```self_count == 0``` otherwise.
 * **label_comparison_fn**: A function that compares two torch arrays of labels and returns a boolean array. The default is ```torch.eq```. If a custom function is used, then you must exclude clustering based metrics ("NMI" and "AMI"). The following is an example of a custom function for two-dimensional labels. It returns ```True``` if the 0th column matches, and the 1st column does **not** match:
 ```python
 def example_label_comparison_fn(x, y):
@@ -60,6 +60,14 @@ def get_accuracy(self,
 
 Note that labels can be 2D if a [custom label comparison function](#using-a-custom-label-comparison-function) is used.
 
+### CPU/GPU usage
+
+* If you installed ```faiss-cpu```, then the CPU will always be used.
+* If you installed ```faiss-gpu```, then the GPU will be used if ```k <= 1024``` for CUDA < 9.5, and ```k <= 2048``` for CUDA >= 9.5. If this condition is not met, then the CPU will be used.
+
+If your dataset is large, you might find that the k-nn search is very slow. This is because the default behavior (```k = None```) is to search the entire dataset. To avoid this, you can set k to a number, like ```k = 1000```, or try ```k = "max_bin_count"```.
+
+
 ### Explanations of the default accuracy metrics
 
 - **AMI**: 
@@ -91,11 +99,7 @@ Note that labels can be 2D if a [custom label comparison function](#using-a-cust
 
 **Important note**
 
-AccuracyCalculator's ```mean_average_precision_at_r``` and ```r_precision``` are correct only if the following are true:
-
-* every query class has less than 1024 samples in the reference set
-* ```k = None```, **or** if you set ```k```, it must be ```>=``` the value that is automatically calculated (see the [parameters](#parameters) section)
-
+AccuracyCalculator's ```mean_average_precision_at_r``` and ```r_precision``` are correct only if ```k = None```, **or** ```k = "max_bin_count"```, **or** ```k >= max(bincount(reference_labels))```.
 
 ### Adding custom accuracy metrics
 
@@ -194,4 +198,14 @@ labels = torch.tensor([
     0.04,
     0.05,
 ])
-```
\ No newline at end of file
+```
+
+
+### Warning for versions <= 0.9.97
+
+The behavior of the ```k``` parameter described in the [Parameters](#parameters) section is for versions >= 0.9.98.
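
To make the new ```k``` options concrete, here is a minimal sketch of how they might be used. The random embeddings and labels are synthetic placeholders; the ```get_accuracy``` signature and the ```"precision_at_1"``` key follow the code elsewhere in this diff:

```python
import torch
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

# synthetic data, just for illustration
embeddings = torch.randn(1000, 64)
labels = torch.randint(0, 10, (1000,))

calculators = {
    "search everything": AccuracyCalculator(k=None),         # k = all reference embeddings (minus the query itself here)
    "fixed k": AccuracyCalculator(k=100),                    # k = 100
    "max bin count": AccuracyCalculator(k="max_bin_count"),  # k = max(bincount(labels)) - 1 here
}

for name, calc in calculators.items():
    acc = calc.get_accuracy(
        embeddings,  # query
        embeddings,  # reference
        labels,      # query_labels
        labels,      # reference_labels
        embeddings_come_from_same_source=True,
        include=("precision_at_1",),
    )
    print(name, acc["precision_at_1"])
```
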
+ +For versions <= 0.9.97, the behavior was: + +* If ```k = None```, then ```k = min(1023, max(bincount(reference_labels)))``` +* Otherwise ```k = k``` \ No newline at end of file diff --git a/docs/imgs/supcon_loss_equation.png b/docs/imgs/supcon_loss_equation.png new file mode 100644 index 00000000..78a08041 Binary files /dev/null and b/docs/imgs/supcon_loss_equation.png differ diff --git a/docs/losses.md b/docs/losses.md index c4ef6c92..3d543e7d 100644 --- a/docs/losses.md +++ b/docs/losses.md @@ -893,6 +893,33 @@ loss_optimizer.step() * **loss**: The loss per element in the batch. Reduction type is ```"element"```. +## SupConLoss +Described in [Supervised Contrastive Learning](https://arxiv.org/abs/2004.11362){target=_blank}. +```python +losses.SupConLoss(temperature=0.1, **kwargs) +``` + +**Equation**: + +![supcon_loss_equation](imgs/supcon_loss_equation.png){: style="height:90px"} + +**Parameters**: + +* **temperature**: This is tau in the above equation. The paper uses 0.1. + +**Default distance**: + + - [```CosineSimilarity()```](distances.md#cosinesimilarity) + +**Default reducer**: + +- [AvgNonZeroReducer](reducers.md#avgnonzeroreducer) + +**Reducer input**: + +* **loss**: The loss per element in the batch. If an element has only negative pairs or no pairs, it's ignored thanks to `AvgNonZeroReducer`. Reduction type is ```"element"```. + + ## TripletMarginLoss ```python diff --git a/examples/notebooks/CascadedEmbeddings.ipynb b/examples/notebooks/CascadedEmbeddings.ipynb index c50598fe..d50e5b51 100644 --- a/examples/notebooks/CascadedEmbeddings.ipynb +++ b/examples/notebooks/CascadedEmbeddings.ipynb @@ -1015,8 +1015,7 @@ { "cell_type": "markdown", "metadata": { - "id": "-LzofXOCE71t", - "colab_type": "text" + "id": "-LzofXOCE71t" }, "source": [ "# PyTorch Metric Learning\n", @@ -1027,8 +1026,7 @@ { "cell_type": "markdown", "metadata": { - "id": "MKpRHvy24tV7", - "colab_type": "text" + "id": "MKpRHvy24tV7" }, "source": [ "## Install the necessary packages" @@ -1038,7 +1036,6 @@ "cell_type": "code", "metadata": { "id": "ZeIGxbbp3W2S", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 87 @@ -1046,7 +1043,8 @@ "outputId": "8ce9b687-a7e5-4880-be70-3df801f7893e" }, "source": [ - "!pip install -q pytorch-metric-learning[with-hooks]" + "!pip install -q pytorch-metric-learning[with-hooks]\n", + "!pip install umap-learn" ], "execution_count": null, "outputs": [ @@ -1065,8 +1063,7 @@ { "cell_type": "markdown", "metadata": { - "id": "BfqRRbIw4zYR", - "colab_type": "text" + "id": "BfqRRbIw4zYR" }, "source": [ "## Import the packages" @@ -1076,7 +1073,6 @@ "cell_type": "code", "metadata": { "id": "567qnmi7wk_M", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 35 @@ -1088,6 +1084,7 @@ "from pytorch_metric_learning import losses, miners, samplers, trainers, testers\n", "from pytorch_metric_learning.utils import common_functions\n", "import pytorch_metric_learning.utils.logging_presets as logging_presets\n", + "from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator\n", "import numpy as np\n", "import torchvision\n", "from torchvision import datasets, transforms\n", @@ -1118,8 +1115,7 @@ { "cell_type": "markdown", "metadata": { - "id": "Qxs6EEeR496q", - "colab_type": "text" + "id": "Qxs6EEeR496q" }, "source": [ "## Model defs" @@ -1128,9 +1124,7 @@ { "cell_type": "code", "metadata": { - "id": "zKyR6gnTwk_P", - "colab_type": "code", - "colab": {} + "id": "zKyR6gnTwk_P" }, "source": [ "class 
MLP(nn.Module):\n", @@ -1184,8 +1178,7 @@ { "cell_type": "markdown", "metadata": { - "id": "btjxk6zR5Cl6", - "colab_type": "text" + "id": "btjxk6zR5Cl6" }, "source": [ "## Initialize models, optimizers and image transforms" @@ -1195,7 +1188,6 @@ "cell_type": "code", "metadata": { "id": "8tzmyFS3wk_R", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 217, @@ -1374,8 +1366,7 @@ { "cell_type": "markdown", "metadata": { - "id": "Xf0xgdWS5GqG", - "colab_type": "text" + "id": "Xf0xgdWS5GqG" }, "source": [ "## Create the dataset and class-disjoint train/val splits" @@ -1385,7 +1376,6 @@ "cell_type": "code", "metadata": { "id": "D-nmnYYAwk_T", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 101, @@ -1470,8 +1460,7 @@ { "cell_type": "markdown", "metadata": { - "id": "r7J817Vs5LNs", - "colab_type": "text" + "id": "r7J817Vs5LNs" }, "source": [ "##Create the loss, miner, sampler, and package them into dictionaries\n" @@ -1481,9 +1470,7 @@ "cell_type": "code", "metadata": { "scrolled": false, - "id": "Kp9AC_4Dwk_V", - "colab_type": "code", - "colab": {} + "id": "Kp9AC_4Dwk_V" }, "source": [ "# Set the loss functions. loss0 will be applied to the first embedder, loss1 to the second embedder etc.\n", @@ -1514,9 +1501,7 @@ { "cell_type": "code", "metadata": { - "id": "23ddQ3anM-sO", - "colab_type": "code", - "colab": {} + "id": "23ddQ3anM-sO" }, "source": [ "# Remove logs if you want to train with new parameters\n", @@ -1528,8 +1513,7 @@ { "cell_type": "markdown", "metadata": { - "id": "A9fa1VYD5Yv0", - "colab_type": "text" + "id": "A9fa1VYD5Yv0" }, "source": [ "## Create the training and testing hooks" @@ -1538,9 +1522,7 @@ { "cell_type": "code", "metadata": { - "id": "Vq_Pd7Pd5Xi_", - "colab_type": "code", - "colab": {} + "id": "Vq_Pd7Pd5Xi_" }, "source": [ "record_keeper, _, _ = logging_presets.get_record_keeper(\"example_logs\", \"example_tensorboard\")\n", @@ -1563,7 +1545,8 @@ "tester = testers.GlobalEmbeddingSpaceTester(end_of_testing_hook = hooks.end_of_testing_hook, \n", " visualizer = umap.UMAP(), \n", " visualizer_hook = visualizer_hook,\n", - " dataloader_num_workers = 32)\n", + " dataloader_num_workers = 32,\n", + " accuracy_calculator=AccuracyCalculator(k=\"max_bin_count\"))\n", "\n", "end_of_epoch_hook = hooks.end_of_epoch_hook(tester, \n", " dataset_dict, \n", @@ -1577,8 +1560,7 @@ { "cell_type": "markdown", "metadata": { - "id": "A0D3Jvxc5iWD", - "colab_type": "text" + "id": "A0D3Jvxc5iWD" }, "source": [ "## Create the trainer" @@ -1589,7 +1571,6 @@ "metadata": { "scrolled": false, "id": "DuASrVs-wk_X", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 35 @@ -1627,8 +1608,7 @@ { "cell_type": "markdown", "metadata": { - "id": "jbt_j2TzQ669", - "colab_type": "text" + "id": "jbt_j2TzQ669" }, "source": [ "## Start Tensorboard\n", @@ -1638,9 +1618,7 @@ { "cell_type": "code", "metadata": { - "id": "s6HCNGf2Q6_X", - "colab_type": "code", - "colab": {} + "id": "s6HCNGf2Q6_X" }, "source": [ "%load_ext tensorboard\n", @@ -1652,8 +1630,7 @@ { "cell_type": "markdown", "metadata": { - "id": "gIq7s7jf5ksj", - "colab_type": "text" + "id": "gIq7s7jf5ksj" }, "source": [ "## Train the model" @@ -1663,7 +1640,6 @@ "cell_type": "code", "metadata": { "id": "WHza2JJHwk_Z", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 1000 diff --git a/examples/notebooks/DeepAdversarialMetricLearning.ipynb b/examples/notebooks/DeepAdversarialMetricLearning.ipynb index 2fe646e0..d4b40460 
100644 --- a/examples/notebooks/DeepAdversarialMetricLearning.ipynb +++ b/examples/notebooks/DeepAdversarialMetricLearning.ipynb @@ -525,8 +525,7 @@ { "cell_type": "markdown", "metadata": { - "id": "f-1bIqrdFKiH", - "colab_type": "text" + "id": "f-1bIqrdFKiH" }, "source": [ "# PyTorch Metric Learning\n", @@ -537,8 +536,7 @@ { "cell_type": "markdown", "metadata": { - "id": "MKpRHvy24tV7", - "colab_type": "text" + "id": "MKpRHvy24tV7" }, "source": [ "## Install the necessary packages" @@ -548,7 +546,6 @@ "cell_type": "code", "metadata": { "id": "ZeIGxbbp3W2S", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 87 @@ -556,7 +553,8 @@ "outputId": "ae98f82a-68d7-46fe-8d3a-30f0c83b59cb" }, "source": [ - "!pip install -q pytorch-metric-learning[with-hooks]" + "!pip install -q pytorch-metric-learning[with-hooks]\n", + "!pip install umap-learn" ], "execution_count": null, "outputs": [ @@ -575,8 +573,7 @@ { "cell_type": "markdown", "metadata": { - "id": "BfqRRbIw4zYR", - "colab_type": "text" + "id": "BfqRRbIw4zYR" }, "source": [ "## Import the packages" @@ -586,7 +583,6 @@ "cell_type": "code", "metadata": { "id": "567qnmi7wk_M", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 35 @@ -598,6 +594,7 @@ "from pytorch_metric_learning import losses, miners, samplers, trainers, testers\n", "from pytorch_metric_learning.utils import common_functions\n", "import pytorch_metric_learning.utils.logging_presets as logging_presets\n", + "from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator\n", "import numpy as np\n", "import torchvision\n", "from torchvision import datasets, transforms\n", @@ -628,8 +625,7 @@ { "cell_type": "markdown", "metadata": { - "id": "Qxs6EEeR496q", - "colab_type": "text" + "id": "Qxs6EEeR496q" }, "source": [ "## Simple model def" @@ -638,9 +634,7 @@ { "cell_type": "code", "metadata": { - "id": "zKyR6gnTwk_P", - "colab_type": "code", - "colab": {} + "id": "zKyR6gnTwk_P" }, "source": [ "class MLP(nn.Module):\n", @@ -670,8 +664,7 @@ { "cell_type": "markdown", "metadata": { - "id": "btjxk6zR5Cl6", - "colab_type": "text" + "id": "btjxk6zR5Cl6" }, "source": [ "## Initialize models, optimizers and image transforms" @@ -681,7 +674,6 @@ "cell_type": "code", "metadata": { "id": "8tzmyFS3wk_R", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 84, @@ -769,8 +761,7 @@ { "cell_type": "markdown", "metadata": { - "id": "Xf0xgdWS5GqG", - "colab_type": "text" + "id": "Xf0xgdWS5GqG" }, "source": [ "## Create the dataset and class-disjoint train/val splits" @@ -780,7 +771,6 @@ "cell_type": "code", "metadata": { "id": "D-nmnYYAwk_T", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 101, @@ -865,8 +855,7 @@ { "cell_type": "markdown", "metadata": { - "id": "r7J817Vs5LNs", - "colab_type": "text" + "id": "r7J817Vs5LNs" }, "source": [ "##Create the loss, miner, sampler, and package them into dictionaries\n" @@ -876,9 +865,7 @@ "cell_type": "code", "metadata": { "scrolled": false, - "id": "Kp9AC_4Dwk_V", - "colab_type": "code", - "colab": {} + "id": "Kp9AC_4Dwk_V" }, "source": [ "# Set the loss function\n", @@ -917,9 +904,7 @@ { "cell_type": "code", "metadata": { - "id": "hg0Gt4VXNCjO", - "colab_type": "code", - "colab": {} + "id": "hg0Gt4VXNCjO" }, "source": [ "# Remove logs if you want to train with new parameters\n", @@ -931,8 +916,7 @@ { "cell_type": "markdown", "metadata": { - "id": "A9fa1VYD5Yv0", - "colab_type": "text" + "id": "A9fa1VYD5Yv0" }, 
"source": [ "## Create the training and testing hooks" @@ -942,7 +926,6 @@ "cell_type": "code", "metadata": { "id": "Vq_Pd7Pd5Xi_", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 35 @@ -970,7 +953,8 @@ "tester = testers.GlobalEmbeddingSpaceTester(end_of_testing_hook = hooks.end_of_testing_hook, \n", " visualizer = umap.UMAP(), \n", " visualizer_hook = visualizer_hook,\n", - " dataloader_num_workers = 32)\n", + " dataloader_num_workers = 32,\n", + " accuracy_calculator=AccuracyCalculator(k=\"max_bin_count\"))\n", "\n", "end_of_epoch_hook = hooks.end_of_epoch_hook(tester, \n", " dataset_dict, \n", @@ -992,8 +976,7 @@ { "cell_type": "markdown", "metadata": { - "id": "A0D3Jvxc5iWD", - "colab_type": "text" + "id": "A0D3Jvxc5iWD" }, "source": [ "## Create the trainer" @@ -1004,7 +987,6 @@ "metadata": { "scrolled": false, "id": "DuASrVs-wk_X", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 35 @@ -1044,8 +1026,7 @@ { "cell_type": "markdown", "metadata": { - "id": "nBzB74hNSILg", - "colab_type": "text" + "id": "nBzB74hNSILg" }, "source": [ "## Start Tensorboard\n", @@ -1055,9 +1036,7 @@ { "cell_type": "code", "metadata": { - "id": "aoJYXjOzSISK", - "colab_type": "code", - "colab": {} + "id": "aoJYXjOzSISK" }, "source": [ "%load_ext tensorboard\n", @@ -1069,8 +1048,7 @@ { "cell_type": "markdown", "metadata": { - "id": "gIq7s7jf5ksj", - "colab_type": "text" + "id": "gIq7s7jf5ksj" }, "source": [ "## Train the model" @@ -1080,7 +1058,6 @@ "cell_type": "code", "metadata": { "id": "WHza2JJHwk_Z", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 1000 diff --git a/examples/notebooks/MetricLossOnly.ipynb b/examples/notebooks/MetricLossOnly.ipynb index afb09e50..e4e33ae5 100644 --- a/examples/notebooks/MetricLossOnly.ipynb +++ b/examples/notebooks/MetricLossOnly.ipynb @@ -524,8 +524,7 @@ { "cell_type": "markdown", "metadata": { - "id": "6-P8A5wm-i5O", - "colab_type": "text" + "id": "6-P8A5wm-i5O" }, "source": [ "# PyTorch Metric Learning\n", @@ -536,8 +535,7 @@ { "cell_type": "markdown", "metadata": { - "id": "MKpRHvy24tV7", - "colab_type": "text" + "id": "MKpRHvy24tV7" }, "source": [ "## Install the necessary packages" @@ -547,7 +545,6 @@ "cell_type": "code", "metadata": { "id": "ZeIGxbbp3W2S", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 87 @@ -555,7 +552,8 @@ "outputId": "4cbced5a-f22d-4a15-b4eb-20177f019ad8" }, "source": [ - "!pip install -q pytorch-metric-learning[with-hooks]" + "!pip install -q pytorch-metric-learning[with-hooks]\n", + "!pip install umap-learn" ], "execution_count": null, "outputs": [ @@ -574,8 +572,7 @@ { "cell_type": "markdown", "metadata": { - "id": "BfqRRbIw4zYR", - "colab_type": "text" + "id": "BfqRRbIw4zYR" }, "source": [ "## Import the packages" @@ -585,7 +582,6 @@ "cell_type": "code", "metadata": { "id": "567qnmi7wk_M", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 35 @@ -597,6 +593,7 @@ "from pytorch_metric_learning import losses, miners, samplers, trainers, testers\n", "from pytorch_metric_learning.utils import common_functions\n", "import pytorch_metric_learning.utils.logging_presets as logging_presets\n", + "from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator\n", "import numpy as np\n", "import torchvision\n", "from torchvision import datasets, transforms\n", @@ -627,8 +624,7 @@ { "cell_type": "markdown", "metadata": { - "id": "Qxs6EEeR496q", - "colab_type": 
"text" + "id": "Qxs6EEeR496q" }, "source": [ "## Simple model def" @@ -637,9 +633,7 @@ { "cell_type": "code", "metadata": { - "id": "zKyR6gnTwk_P", - "colab_type": "code", - "colab": {} + "id": "zKyR6gnTwk_P" }, "source": [ "class MLP(nn.Module):\n", @@ -669,8 +663,7 @@ { "cell_type": "markdown", "metadata": { - "id": "btjxk6zR5Cl6", - "colab_type": "text" + "id": "btjxk6zR5Cl6" }, "source": [ "## Initialize models, optimizers and image transforms" @@ -680,7 +673,6 @@ "cell_type": "code", "metadata": { "id": "8tzmyFS3wk_R", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 84, @@ -764,8 +756,7 @@ { "cell_type": "markdown", "metadata": { - "id": "Xf0xgdWS5GqG", - "colab_type": "text" + "id": "Xf0xgdWS5GqG" }, "source": [ "## Create the dataset and class-disjoint train/val splits" @@ -775,7 +766,6 @@ "cell_type": "code", "metadata": { "id": "D-nmnYYAwk_T", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 101, @@ -860,8 +850,7 @@ { "cell_type": "markdown", "metadata": { - "id": "r7J817Vs5LNs", - "colab_type": "text" + "id": "r7J817Vs5LNs" }, "source": [ "##Create the loss, miner, sampler, and package them into dictionaries\n" @@ -871,9 +860,7 @@ "cell_type": "code", "metadata": { "scrolled": false, - "id": "Kp9AC_4Dwk_V", - "colab_type": "code", - "colab": {} + "id": "Kp9AC_4Dwk_V" }, "source": [ "# Set the loss function\n", @@ -901,9 +888,7 @@ { "cell_type": "code", "metadata": { - "id": "6XfLY1vHMvwM", - "colab_type": "code", - "colab": {} + "id": "6XfLY1vHMvwM" }, "source": [ "# Remove logs if you want to train with new parameters\n", @@ -915,8 +900,7 @@ { "cell_type": "markdown", "metadata": { - "id": "A9fa1VYD5Yv0", - "colab_type": "text" + "id": "A9fa1VYD5Yv0" }, "source": [ "## Create the training and testing hooks" @@ -926,7 +910,6 @@ "cell_type": "code", "metadata": { "id": "Vq_Pd7Pd5Xi_", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 35 @@ -954,7 +937,8 @@ "tester = testers.GlobalEmbeddingSpaceTester(end_of_testing_hook = hooks.end_of_testing_hook, \n", " visualizer = umap.UMAP(), \n", " visualizer_hook = visualizer_hook,\n", - " dataloader_num_workers = 32)\n", + " dataloader_num_workers = 32,\n", + " accuracy_calculator=AccuracyCalculator(k=\"max_bin_count\"))\n", "\n", "end_of_epoch_hook = hooks.end_of_epoch_hook(tester, \n", " dataset_dict, \n", @@ -976,8 +960,7 @@ { "cell_type": "markdown", "metadata": { - "id": "A0D3Jvxc5iWD", - "colab_type": "text" + "id": "A0D3Jvxc5iWD" }, "source": [ "## Create the trainer" @@ -987,9 +970,7 @@ "cell_type": "code", "metadata": { "scrolled": false, - "id": "DuASrVs-wk_X", - "colab_type": "code", - "colab": {} + "id": "DuASrVs-wk_X" }, "source": [ "trainer = trainers.MetricLossOnly(models,\n", @@ -1009,8 +990,7 @@ { "cell_type": "markdown", "metadata": { - "id": "yl_z0VoXIKrk", - "colab_type": "text" + "id": "yl_z0VoXIKrk" }, "source": [ "## Start Tensorboard\n", @@ -1020,9 +1000,7 @@ { "cell_type": "code", "metadata": { - "id": "35aR2k0zIK4V", - "colab_type": "code", - "colab": {} + "id": "35aR2k0zIK4V" }, "source": [ "%load_ext tensorboard\n", @@ -1034,8 +1012,7 @@ { "cell_type": "markdown", "metadata": { - "id": "gIq7s7jf5ksj", - "colab_type": "text" + "id": "gIq7s7jf5ksj" }, "source": [ "## Train the model" @@ -1045,7 +1022,6 @@ "cell_type": "code", "metadata": { "id": "WHza2JJHwk_Z", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 1000 diff --git a/examples/notebooks/TrainWithClassifier.ipynb 
b/examples/notebooks/TrainWithClassifier.ipynb index 4e8ee4c7..d43be647 100644 --- a/examples/notebooks/TrainWithClassifier.ipynb +++ b/examples/notebooks/TrainWithClassifier.ipynb @@ -31,8 +31,7 @@ { "cell_type": "markdown", "metadata": { - "id": "4pR7wq97EtSE", - "colab_type": "text" + "id": "4pR7wq97EtSE" }, "source": [ "# PyTorch Metric Learning\n", @@ -43,8 +42,7 @@ { "cell_type": "markdown", "metadata": { - "id": "MKpRHvy24tV7", - "colab_type": "text" + "id": "MKpRHvy24tV7" }, "source": [ "## Install the necessary packages" @@ -53,12 +51,11 @@ { "cell_type": "code", "metadata": { - "id": "ZeIGxbbp3W2S", - "colab_type": "code", - "colab": {} + "id": "ZeIGxbbp3W2S" }, "source": [ - "!pip install -q pytorch-metric-learning[with-hooks]" + "!pip install -q pytorch-metric-learning[with-hooks]\n", + "!pip install umap-learn" ], "execution_count": null, "outputs": [] @@ -66,8 +63,7 @@ { "cell_type": "markdown", "metadata": { - "id": "BfqRRbIw4zYR", - "colab_type": "text" + "id": "BfqRRbIw4zYR" }, "source": [ "## Import the packages" @@ -77,7 +73,6 @@ "cell_type": "code", "metadata": { "id": "567qnmi7wk_M", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 35 @@ -89,6 +84,7 @@ "from pytorch_metric_learning import losses, miners, samplers, trainers, testers\n", "from pytorch_metric_learning.utils import common_functions\n", "import pytorch_metric_learning.utils.logging_presets as logging_presets\n", + "from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator\n", "import numpy as np\n", "import torchvision\n", "from torchvision import datasets, transforms\n", @@ -119,8 +115,7 @@ { "cell_type": "markdown", "metadata": { - "id": "Qxs6EEeR496q", - "colab_type": "text" + "id": "Qxs6EEeR496q" }, "source": [ "## Simple model def" @@ -129,9 +124,7 @@ { "cell_type": "code", "metadata": { - "id": "zKyR6gnTwk_P", - "colab_type": "code", - "colab": {} + "id": "zKyR6gnTwk_P" }, "source": [ "class MLP(nn.Module):\n", @@ -161,8 +154,7 @@ { "cell_type": "markdown", "metadata": { - "id": "btjxk6zR5Cl6", - "colab_type": "text" + "id": "btjxk6zR5Cl6" }, "source": [ "## Initialize models, optimizers and image transforms" @@ -171,9 +163,7 @@ { "cell_type": "code", "metadata": { - "id": "8tzmyFS3wk_R", - "colab_type": "code", - "colab": {} + "id": "8tzmyFS3wk_R" }, "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", @@ -217,8 +207,7 @@ { "cell_type": "markdown", "metadata": { - "id": "Xf0xgdWS5GqG", - "colab_type": "text" + "id": "Xf0xgdWS5GqG" }, "source": [ "## Create the dataset and class-disjoint train/val splits" @@ -228,7 +217,6 @@ "cell_type": "code", "metadata": { "id": "D-nmnYYAwk_T", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 52 @@ -280,8 +268,7 @@ { "cell_type": "markdown", "metadata": { - "id": "r7J817Vs5LNs", - "colab_type": "text" + "id": "r7J817Vs5LNs" }, "source": [ "##Create the loss, miner, sampler, and package them into dictionaries\n" @@ -291,9 +278,7 @@ "cell_type": "code", "metadata": { "scrolled": false, - "id": "Kp9AC_4Dwk_V", - "colab_type": "code", - "colab": {} + "id": "Kp9AC_4Dwk_V" }, "source": [ "# Set the loss function\n", @@ -327,9 +312,7 @@ { "cell_type": "code", "metadata": { - "id": "sXMP1gI1Sdam", - "colab_type": "code", - "colab": {} + "id": "sXMP1gI1Sdam" }, "source": [ "# Remove logs if you want to train with new parameters\n", @@ -341,8 +324,7 @@ { "cell_type": "markdown", "metadata": { - "id": "A9fa1VYD5Yv0", - "colab_type": "text" + "id": 
"A9fa1VYD5Yv0" }, "source": [ "## Create the training and testing hooks" @@ -351,9 +333,7 @@ { "cell_type": "code", "metadata": { - "id": "Vq_Pd7Pd5Xi_", - "colab_type": "code", - "colab": {} + "id": "Vq_Pd7Pd5Xi_" }, "source": [ "record_keeper, _, _ = logging_presets.get_record_keeper(\"example_logs\", \"example_tensorboard\")\n", @@ -376,7 +356,8 @@ "tester = testers.GlobalEmbeddingSpaceTester(end_of_testing_hook = hooks.end_of_testing_hook, \n", " visualizer = umap.UMAP(), \n", " visualizer_hook = visualizer_hook,\n", - " dataloader_num_workers = 32)\n", + " dataloader_num_workers = 32,\n", + " accuracy_calculator=AccuracyCalculator(k=\"max_bin_count\"))\n", "\n", "end_of_epoch_hook = hooks.end_of_epoch_hook(tester, \n", " dataset_dict, \n", @@ -390,8 +371,7 @@ { "cell_type": "markdown", "metadata": { - "id": "A0D3Jvxc5iWD", - "colab_type": "text" + "id": "A0D3Jvxc5iWD" }, "source": [ "## Create the trainer" @@ -401,9 +381,7 @@ "cell_type": "code", "metadata": { "scrolled": false, - "id": "DuASrVs-wk_X", - "colab_type": "code", - "colab": {} + "id": "DuASrVs-wk_X" }, "source": [ "trainer = trainers.TrainWithClassifier(models,\n", @@ -424,8 +402,7 @@ { "cell_type": "markdown", "metadata": { - "id": "GIlMMUxSPLni", - "colab_type": "text" + "id": "GIlMMUxSPLni" }, "source": [ "## Start Tensorboard\n", @@ -435,9 +412,7 @@ { "cell_type": "code", "metadata": { - "id": "ikIOmWNNPLtg", - "colab_type": "code", - "colab": {} + "id": "ikIOmWNNPLtg" }, "source": [ "%load_ext tensorboard\n", @@ -449,8 +424,7 @@ { "cell_type": "markdown", "metadata": { - "id": "gIq7s7jf5ksj", - "colab_type": "text" + "id": "gIq7s7jf5ksj" }, "source": [ "## Train the model" @@ -460,7 +434,6 @@ "cell_type": "code", "metadata": { "id": "WHza2JJHwk_Z", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 1000 diff --git a/examples/notebooks/TwoStreamMetricLoss.ipynb b/examples/notebooks/TwoStreamMetricLoss.ipynb index c2e3b53d..19bbd989 100644 --- a/examples/notebooks/TwoStreamMetricLoss.ipynb +++ b/examples/notebooks/TwoStreamMetricLoss.ipynb @@ -512,8 +512,7 @@ { "cell_type": "markdown", "metadata": { - "id": "nYDCtu679yz1", - "colab_type": "text" + "id": "nYDCtu679yz1" }, "source": [ "# PyTorch Metric Learning\n", @@ -524,8 +523,7 @@ { "cell_type": "markdown", "metadata": { - "id": "ChQnqzgP9tPj", - "colab_type": "text" + "id": "ChQnqzgP9tPj" }, "source": [ "## Install prereqs" @@ -535,7 +533,6 @@ "cell_type": "code", "metadata": { "id": "nLr_MM936Wd5", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 87 @@ -543,7 +540,8 @@ "outputId": "f8d01c9e-4943-4128-d34d-d16b76a737c2" }, "source": [ - "!pip install -q pytorch-metric-learning[with-hooks]" + "!pip install -q pytorch-metric-learning[with-hooks]\n", + "!pip install umap-learn" ], "execution_count": null, "outputs": [ @@ -562,8 +560,7 @@ { "cell_type": "markdown", "metadata": { - "id": "PyABRGhcXrei", - "colab_type": "text" + "id": "PyABRGhcXrei" }, "source": [ "## Import the packages" @@ -573,7 +570,6 @@ "cell_type": "code", "metadata": { "id": "VULYLloy9ivc", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 35 @@ -585,6 +581,7 @@ "from pytorch_metric_learning import losses, miners, samplers, trainers, testers\n", "from pytorch_metric_learning.utils import common_functions\n", "import pytorch_metric_learning.utils.logging_presets as logging_presets\n", + "from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator\n", "import numpy as np\n", 
"import torchvision\n", "from torchvision import datasets, transforms\n", @@ -614,8 +611,7 @@ { "cell_type": "markdown", "metadata": { - "id": "2bT2tUtC9mND", - "colab_type": "text" + "id": "2bT2tUtC9mND" }, "source": [ "## Create two-stream dataset from CIFAR100" @@ -624,9 +620,7 @@ { "cell_type": "code", "metadata": { - "id": "VOWPA22W9lX0", - "colab_type": "code", - "colab": {} + "id": "VOWPA22W9lX0" }, "source": [ "class CIFAR100TwoStreamDataset(torch.utils.data.Dataset):\n", @@ -661,8 +655,7 @@ { "cell_type": "markdown", "metadata": { - "id": "b79C0ZqXXyKx", - "colab_type": "text" + "id": "b79C0ZqXXyKx" }, "source": [ "##Simple model def" @@ -671,9 +664,7 @@ { "cell_type": "code", "metadata": { - "id": "THzceJtN_p78", - "colab_type": "code", - "colab": {} + "id": "THzceJtN_p78" }, "source": [ "class MLP(nn.Module):\n", @@ -703,8 +694,7 @@ { "cell_type": "markdown", "metadata": { - "id": "LTiWoweRX1T5", - "colab_type": "text" + "id": "LTiWoweRX1T5" }, "source": [ "## Initialize models, optimizers and image transforms" @@ -714,7 +704,6 @@ "cell_type": "code", "metadata": { "id": "1A6ad-I7_smx", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 84, @@ -795,8 +784,7 @@ { "cell_type": "markdown", "metadata": { - "id": "In_aiu4BX9zF", - "colab_type": "text" + "id": "In_aiu4BX9zF" }, "source": [ "## Initialize the datasets" @@ -806,7 +794,6 @@ "cell_type": "code", "metadata": { "id": "llktnjZJX8cN", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 101, @@ -873,8 +860,7 @@ { "cell_type": "markdown", "metadata": { - "id": "SG5rnOxHYGQ_", - "colab_type": "text" + "id": "SG5rnOxHYGQ_" }, "source": [ "## Create the loss, miner, sampler, and package them into dictionaries" @@ -883,9 +869,7 @@ { "cell_type": "code", "metadata": { - "id": "B9Q7B0OKXkil", - "colab_type": "code", - "colab": {} + "id": "B9Q7B0OKXkil" }, "source": [ "# Set the loss function\n", @@ -913,9 +897,7 @@ { "cell_type": "code", "metadata": { - "id": "5wPolfqrXZ4I", - "colab_type": "code", - "colab": {} + "id": "5wPolfqrXZ4I" }, "source": [ "# Remove logs if you want to train with new parameters\n", @@ -927,8 +909,7 @@ { "cell_type": "markdown", "metadata": { - "id": "BLyl-1i_YJpJ", - "colab_type": "text" + "id": "BLyl-1i_YJpJ" }, "source": [ "## Create the training and testing hooks" @@ -937,9 +918,7 @@ { "cell_type": "code", "metadata": { - "id": "i4qQJLET_0iB", - "colab_type": "code", - "colab": {} + "id": "i4qQJLET_0iB" }, "source": [ "record_keeper, _, _ = logging_presets.get_record_keeper(\"example_logs\", \"example_tensorboard\")\n", @@ -967,7 +946,8 @@ "tester = testers.GlobalTwoStreamEmbeddingSpaceTester(end_of_testing_hook = hooks.end_of_testing_hook, \n", " visualizer = umap.UMAP(n_neighbors=50), \n", " visualizer_hook = visualizer_hook,\n", - " dataloader_num_workers = 32)\n", + " dataloader_num_workers = 32,\n", + " accuracy_calculator=AccuracyCalculator(k=\"max_bin_count\"))\n", "\n", "end_of_epoch_hook = hooks.end_of_epoch_hook(tester, dataset_dict, model_folder)" ], @@ -977,8 +957,7 @@ { "cell_type": "markdown", "metadata": { - "id": "8Fu2GPXoYSkA", - "colab_type": "text" + "id": "8Fu2GPXoYSkA" }, "source": [ "## Create the trainer" @@ -987,9 +966,7 @@ { "cell_type": "code", "metadata": { - "id": "R4IYl6dEYQFi", - "colab_type": "code", - "colab": {} + "id": "R4IYl6dEYQFi" }, "source": [ "trainer = trainers.TwoStreamMetricLoss(models,\n", @@ -1010,8 +987,7 @@ { "cell_type": "markdown", "metadata": { - "id": "LCjFfzIgODQ0", - "colab_type": "text" 
+ "id": "LCjFfzIgODQ0" }, "source": [ "## Start Tensorboard\n", @@ -1021,9 +997,7 @@ { "cell_type": "code", "metadata": { - "id": "J7JvXuskODXY", - "colab_type": "code", - "colab": {} + "id": "J7JvXuskODXY" }, "source": [ "%load_ext tensorboard\n", @@ -1035,8 +1009,7 @@ { "cell_type": "markdown", "metadata": { - "id": "I20LpzvEYT_y", - "colab_type": "text" + "id": "I20LpzvEYT_y" }, "source": [ "## Train the model" @@ -1046,7 +1019,6 @@ "cell_type": "code", "metadata": { "id": "_Zs2-DveYQfW", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 1000 diff --git a/src/pytorch_metric_learning/__init__.py b/src/pytorch_metric_learning/__init__.py index c6ddc6d6..4ae67c0f 100644 --- a/src/pytorch_metric_learning/__init__.py +++ b/src/pytorch_metric_learning/__init__.py @@ -1 +1 @@ -__version__ = "0.9.97" +__version__ = "0.9.98" diff --git a/src/pytorch_metric_learning/losses/__init__.py b/src/pytorch_metric_learning/losses/__init__.py index 26596572..54911ae2 100644 --- a/src/pytorch_metric_learning/losses/__init__.py +++ b/src/pytorch_metric_learning/losses/__init__.py @@ -22,5 +22,6 @@ from .signal_to_noise_ratio_losses import SignalToNoiseRatioContrastiveLoss from .soft_triple_loss import SoftTripleLoss from .sphereface_loss import SphereFaceLoss +from .supcon_loss import SupConLoss from .triplet_margin_loss import TripletMarginLoss from .tuplet_margin_loss import TupletMarginLoss diff --git a/src/pytorch_metric_learning/losses/supcon_loss.py b/src/pytorch_metric_learning/losses/supcon_loss.py new file mode 100644 index 00000000..1fb32b6b --- /dev/null +++ b/src/pytorch_metric_learning/losses/supcon_loss.py @@ -0,0 +1,45 @@ +from ..distances import CosineSimilarity +from ..reducers import AvgNonZeroReducer +from ..utils import common_functions as c_f +from ..utils import loss_and_miner_utils as lmu +from .generic_pair_loss import GenericPairLoss + + +# adapted from https://github.com/HobbitLong/SupContrast +class SupConLoss(GenericPairLoss): + def __init__(self, temperature=0.1, **kwargs): + super().__init__(mat_based_loss=True, **kwargs) + self.temperature = temperature + self.add_to_recordable_attributes(list_of_names=["temperature"], is_stat=False) + + def _compute_loss(self, mat, pos_mask, neg_mask): + if pos_mask.bool().any() and neg_mask.bool().any(): + # if dealing with actual distances, use negative distances + if not self.distance.is_inverted: + mat = -mat + mat = mat / self.temperature + mat_max, _ = mat.max(dim=1, keepdim=True) + mat = mat - mat_max.detach() # for numerical stability + + denominator = lmu.logsumexp( + mat, keep_mask=(pos_mask + neg_mask).bool(), add_one=False, dim=1 + ) + log_prob = mat - denominator + mean_log_prob_pos = (pos_mask * log_prob).sum(dim=1) / ( + pos_mask.sum(dim=1) + c_f.small_val(mat.dtype) + ) + + return { + "loss": { + "losses": -mean_log_prob_pos, + "indices": c_f.torch_arange_from_size(mat), + "reduction_type": "element", + } + } + return self.zero_losses() + + def get_default_reducer(self): + return AvgNonZeroReducer() + + def get_default_distance(self): + return CosineSimilarity() diff --git a/src/pytorch_metric_learning/testers/base_tester.py b/src/pytorch_metric_learning/testers/base_tester.py index 624b4650..4154b950 100644 --- a/src/pytorch_metric_learning/testers/base_tester.py +++ b/src/pytorch_metric_learning/testers/base_tester.py @@ -273,7 +273,6 @@ def test( embedder_model=None, splits_to_eval=None, collate_fn=None, - **kwargs, ): logging.info("Evaluating epoch {}".format(epoch)) if embedder_model is 
None: diff --git a/src/pytorch_metric_learning/utils/accuracy_calculator.py b/src/pytorch_metric_learning/utils/accuracy_calculator.py index 85e9d79f..5d4f83a5 100644 --- a/src/pytorch_metric_learning/utils/accuracy_calculator.py +++ b/src/pytorch_metric_learning/utils/accuracy_calculator.py @@ -153,10 +153,7 @@ def get_label_match_counts(query_labels, reference_labels, label_comparison_fn): label_comparison_fn(label_a, reference_labels) ) - # faiss can only do a max of k=1024, and we have to do k+1 - num_k = int(min(1023, torch.max(match_counts))) - - return (unique_query_labels, match_counts), num_k + return (unique_query_labels, match_counts) def get_lone_query_labels( @@ -214,6 +211,11 @@ def __init__( self.original_function_dict = self.get_function_dict(include, exclude) self.curr_function_dict = self.get_function_dict() self.avg_of_avgs = avg_of_avgs + + if (not (isinstance(k, int) and k > 0)) and (k not in [None, "max_bin_count"]): + raise ValueError( + "k must be an integer greater than 0, or None, or 'max_bin_count'" + ) self.k = k if label_comparison_fn: @@ -359,10 +361,6 @@ def get_accuracy( include=(), exclude=(), ): - embeddings_come_from_same_source = embeddings_come_from_same_source or ( - query is reference - ) - [query, reference, query_labels, reference_labels] = [ c_f.numpy_to_torch(x) for x in [query, reference, query_labels, reference_labels] @@ -380,7 +378,7 @@ def get_accuracy( } if any(x in self.requires_knn() for x in self.get_curr_metrics()): - label_counts, num_k = get_label_match_counts( + label_counts = get_label_match_counts( query_labels, reference_labels, self.label_comparison_fn ) lone_query_labels, not_lone_query_mask = get_lone_query_labels( @@ -390,8 +388,10 @@ def get_accuracy( self.label_comparison_fn, ) - if self.k is not None: - num_k = self.k + num_k = self.determine_k( + label_counts[1], len(reference), embeddings_come_from_same_source + ) + knn_indices, knn_distances = stat_utils.get_knn( reference, query, num_k, embeddings_come_from_same_source ) @@ -427,5 +427,15 @@ def check_primary_metrics(calc, include=(), exclude=()): ) ) + def determine_k( + self, bin_counts, num_reference_embeddings, embeddings_come_from_same_source + ): + self_count = int(embeddings_come_from_same_source) + if self.k == "max_bin_count": + return torch.max(bin_counts).item() - self_count + if self.k is None: + return num_reference_embeddings - self_count + return self.k + def description(self): return "avg_of_avgs" if self.avg_of_avgs else "" diff --git a/src/pytorch_metric_learning/utils/common_functions.py b/src/pytorch_metric_learning/utils/common_functions.py index 3859d60a..2ed48957 100644 --- a/src/pytorch_metric_learning/utils/common_functions.py +++ b/src/pytorch_metric_learning/utils/common_functions.py @@ -458,9 +458,10 @@ def torch_standard_scaler(x): def to_dtype(x, tensor=None, dtype=None): - dt = dtype if dtype is not None else tensor.dtype - if x.dtype != dt: - x = x.type(dt) + if not torch.is_autocast_enabled(): + dt = dtype if dtype is not None else tensor.dtype + if x.dtype != dt: + x = x.type(dt) return x diff --git a/src/pytorch_metric_learning/utils/loss_and_miner_utils.py b/src/pytorch_metric_learning/utils/loss_and_miner_utils.py index a077861c..41b92cf7 100644 --- a/src/pytorch_metric_learning/utils/loss_and_miner_utils.py +++ b/src/pytorch_metric_learning/utils/loss_and_miner_utils.py @@ -191,29 +191,9 @@ def convert_to_triplets(indices_tuple, labels, t_per_anchor=100): elif len(indices_tuple) == 3: return indices_tuple else: - a_out, p_out, n_out 
= [], [], [] a1, p, a2, n = indices_tuple - empty_output = [torch.tensor([], device=labels.device)] * 3 - if len(a1) == 0 or len(a2) == 0: - return empty_output - for i in range(len(labels)): - pos_idx = torch.where(a1 == i)[0] - neg_idx = torch.where(a2 == i)[0] - if len(pos_idx) > 0 and len(neg_idx) > 0: - p_idx = p[pos_idx] - n_idx = n[neg_idx] - p_idx, n_idx = matched_size_indices(p_idx, n_idx) - a_idx = torch.ones_like(c_f.longest_list([p_idx, n_idx])) * i - a_out.append(a_idx) - p_out.append(p_idx) - n_out.append(n_idx) - try: - return [torch.cat(x, dim=0) for x in [a_out, p_out, n_out]] - except RuntimeError: - # assert that the exception was caused by disjoint a1 and a2 - # otherwise something has gone wrong - assert len(np.intersect1d(a1, a2)) == 0 - return empty_output + p_idx, n_idx = torch.where(a1.unsqueeze(1) == a2) + return a1[p_idx], p[p_idx], n[n_idx] def convert_to_weights(indices_tuple, labels, dtype): diff --git a/src/pytorch_metric_learning/utils/stat_utils.py b/src/pytorch_metric_learning/utils/stat_utils.py index ef233051..7d88d794 100644 --- a/src/pytorch_metric_learning/utils/stat_utils.py +++ b/src/pytorch_metric_learning/utils/stat_utils.py @@ -14,10 +14,39 @@ from . import common_functions as c_f +def add_to_index_and_search(index, reference_embeddings, test_embeddings, k): + index.add(reference_embeddings) + return index.search(test_embeddings, k) + + +def try_gpu(cpu_index, reference_embeddings, test_embeddings, k): + # https://github.com/facebookresearch/faiss/blob/master/faiss/gpu/utils/DeviceDefs.cuh + gpu_index = None + gpus_are_available = faiss.get_num_gpus() > 0 + if gpus_are_available: + max_k_for_gpu = 1024 if float(torch.version.cuda) < 9.5 else 2048 + if k <= max_k_for_gpu: + gpu_index = faiss.index_cpu_to_all_gpus(cpu_index) + try: + return add_to_index_and_search( + gpu_index, reference_embeddings, test_embeddings, k + ) + except (AttributeError, RuntimeError) as e: + if gpus_are_available: + logging.warning( + f"Using CPU for k-nn search because k = {k} > {max_k_for_gpu}, which is the maximum allowable on GPU." 
+ ) + return add_to_index_and_search( + cpu_index, reference_embeddings, test_embeddings, k + ) + + # modified from https://github.com/facebookresearch/deepcluster def get_knn( reference_embeddings, test_embeddings, k, embeddings_come_from_same_source=False ): + if embeddings_come_from_same_source: + k = k + 1 device = reference_embeddings.device reference_embeddings = c_f.to_numpy(reference_embeddings).astype(np.float32) test_embeddings = c_f.to_numpy(test_embeddings).astype(np.float32) @@ -25,16 +54,13 @@ def get_knn( d = reference_embeddings.shape[1] logging.info("running k-nn with k=%d" % k) logging.info("embedding dimensionality is %d" % d) - index = faiss.IndexFlatL2(d) - if faiss.get_num_gpus() > 0: - index = faiss.index_cpu_to_all_gpus(index) - index.add(reference_embeddings) - distances, indices = index.search(test_embeddings, k + 1) + cpu_index = faiss.IndexFlatL2(d) + distances, indices = try_gpu(cpu_index, reference_embeddings, test_embeddings, k) distances = c_f.to_device(torch.from_numpy(distances), device=device) indices = c_f.to_device(torch.from_numpy(indices), device=device) if embeddings_come_from_same_source: return indices[:, 1:], distances[:, 1:] - return indices[:, :k], distances[:, :k] + return indices, distances # modified from https://raw.githubusercontent.com/facebookresearch/deepcluster/ diff --git a/tests/losses/test_ntxent_loss.py b/tests/losses/test_ntxent_loss.py index 1bda4201..718fb561 100644 --- a/tests/losses/test_ntxent_loss.py +++ b/tests/losses/test_ntxent_loss.py @@ -3,8 +3,8 @@ import torch from pytorch_metric_learning.distances import LpDistance -from pytorch_metric_learning.losses import NTXentLoss -from pytorch_metric_learning.reducers import PerAnchorReducer +from pytorch_metric_learning.losses import NTXentLoss, SupConLoss +from pytorch_metric_learning.reducers import AvgNonZeroReducer, PerAnchorReducer from pytorch_metric_learning.utils import common_functions as c_f from .. 
import TEST_DEVICE, TEST_DTYPES @@ -15,10 +15,14 @@ def test_ntxent_loss(self): temperature = 0.1 loss_funcA = NTXentLoss(temperature=temperature) loss_funcB = NTXentLoss(temperature=temperature, distance=LpDistance()) - loss_funcC = NTXentLoss(temperature=temperature, reducer=PerAnchorReducer()) + loss_funcC = NTXentLoss( + temperature=temperature, reducer=PerAnchorReducer(AvgNonZeroReducer()) + ) + loss_funcD = SupConLoss(temperature=temperature) + loss_funcE = SupConLoss(temperature=temperature, distance=LpDistance()) for dtype in TEST_DTYPES: - embedding_angles = [0, 20, 40, 60, 80, 100] + embedding_angles = [0, 10, 20, 50, 60, 80] embeddings = torch.tensor( [c_f.angle_to_coord(a) for a in embedding_angles], requires_grad=True, @@ -29,9 +33,10 @@ def test_ntxent_loss(self): labels = torch.LongTensor([0, 0, 0, 1, 1, 2]) - lossA = loss_funcA(embeddings, labels) - lossB = loss_funcB(embeddings, labels) - lossC = loss_funcC(embeddings, labels) + obtained_losses = [ + x(embeddings, labels) + for x in [loss_funcA, loss_funcB, loss_funcC, loss_funcD, loss_funcE] + ] pos_pairs = [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1), (3, 4), (4, 3)] neg_pairs = [ @@ -59,10 +64,12 @@ def test_ntxent_loss(self): (5, 4), ] - total_lossA, total_lossB, total_lossC = ( + total_lossA, total_lossB, total_lossC, total_lossD, total_lossE = ( 0, 0, - torch.zeros(6, device=TEST_DEVICE, dtype=dtype), + torch.zeros(5, device=TEST_DEVICE, dtype=dtype), + torch.zeros(5, device=TEST_DEVICE, dtype=dtype), + torch.zeros(5, device=TEST_DEVICE, dtype=dtype), ) for a1, p in pos_pairs: anchor, positive = embeddings[a1], embeddings[p] @@ -72,69 +79,95 @@ def test_ntxent_loss(self): ) denominatorA = numeratorA.clone() denominatorB = numeratorB.clone() - for a2, n in neg_pairs: + denominatorD = 0 + denominatorE = 0 + for a2, n in pos_pairs + neg_pairs: if a2 == a1: negative = embeddings[n] + curr_denomD = torch.exp( + torch.matmul(anchor, negative) / temperature + ) + curr_denomE = torch.exp( + -torch.sqrt(torch.sum((anchor - negative) ** 2)) + / temperature + ) + denominatorD += curr_denomD + denominatorE += curr_denomE + if (a2, n) not in pos_pairs: + denominatorA += curr_denomD + denominatorB += curr_denomE else: continue - denominatorA += torch.exp( - torch.matmul(anchor, negative) / temperature - ) - denominatorB += torch.exp( - -torch.sqrt(torch.sum((anchor - negative) ** 2)) / temperature - ) + curr_lossA = -torch.log(numeratorA / denominatorA) curr_lossB = -torch.log(numeratorB / denominatorB) + curr_lossD = -torch.log(numeratorA / denominatorD) + curr_lossE = -torch.log(numeratorB / denominatorE) total_lossA += curr_lossA total_lossB += curr_lossB total_lossC[a1] += curr_lossA + total_lossD[a1] += curr_lossD + total_lossE[a1] += curr_lossE total_lossA /= len(pos_pairs) total_lossB /= len(pos_pairs) pos_pair_per_anchor = torch.tensor( - [2, 2, 2, 1, 1, 0], device=TEST_DEVICE, dtype=dtype + [2, 2, 2, 1, 1], device=TEST_DEVICE, dtype=dtype ) - total_lossC = total_lossC / pos_pair_per_anchor - total_lossC[pos_pair_per_anchor == 0] = 0 - - total_lossC = torch.mean(total_lossC) + total_lossC, total_lossD, total_lossE = [ + torch.mean(x / pos_pair_per_anchor) + for x in [total_lossC, total_lossD, total_lossE] + ] rtol = 1e-2 if dtype == torch.float16 else 1e-5 - self.assertTrue(torch.isclose(lossA, total_lossA, rtol=rtol)) - self.assertTrue(torch.isclose(lossB, total_lossB, rtol=rtol)) - self.assertTrue(torch.isclose(lossC, total_lossC, rtol=rtol)) + self.assertTrue(torch.isclose(obtained_losses[0], total_lossA, rtol=rtol)) 
+ self.assertTrue(torch.isclose(obtained_losses[1], total_lossB, rtol=rtol)) + self.assertTrue(torch.isclose(obtained_losses[2], total_lossC, rtol=rtol)) + self.assertTrue(torch.isclose(obtained_losses[3], total_lossD, rtol=rtol)) + self.assertTrue(torch.isclose(obtained_losses[4], total_lossE, rtol=rtol)) def test_with_no_valid_pairs(self): - loss_func = NTXentLoss(temperature=0.1) - all_embedding_angles = [[0], [0, 10, 20]] - all_labels = [torch.LongTensor([0]), torch.LongTensor([0, 0, 0])] - for dtype in TEST_DTYPES: - for embedding_angles, labels in zip(all_embedding_angles, all_labels): - embeddings = torch.tensor( - [c_f.angle_to_coord(a) for a in embedding_angles], - requires_grad=True, - dtype=dtype, - ).to( - TEST_DEVICE - ) # 2D embeddings - loss = loss_func(embeddings, labels) - loss.backward() - self.assertEqual(loss, 0) + all_embedding_angles = [[0], [0, 10, 20], [0, 40, 60]] + all_labels = [ + torch.LongTensor([0]), + torch.LongTensor([0, 0, 0]), + torch.LongTensor([1, 2, 3]), + ] + temperature = 0.1 + for loss_class in [NTXentLoss, SupConLoss]: + loss_funcA = loss_class(temperature) + loss_funcB = loss_class(temperature, distance=LpDistance()) + for loss_func in [loss_funcA, loss_funcB]: + for dtype in TEST_DTYPES: + for embedding_angles, labels in zip( + all_embedding_angles, all_labels + ): + embeddings = torch.tensor( + [c_f.angle_to_coord(a) for a in embedding_angles], + requires_grad=True, + dtype=dtype, + ).to( + TEST_DEVICE + ) # 2D embeddings + loss = loss_func(embeddings, labels) + loss.backward() + self.assertEqual(loss, 0) def test_backward(self): temperature = 0.1 - loss_funcA = NTXentLoss(temperature=temperature) - loss_funcB = NTXentLoss(temperature=temperature, distance=LpDistance()) - for dtype in TEST_DTYPES: - for loss_func in [loss_funcA, loss_funcB]: - embedding_angles = [0, 20, 40, 60, 80] - embeddings = torch.tensor( - [c_f.angle_to_coord(a) for a in embedding_angles], - requires_grad=True, - dtype=dtype, - ).to( - TEST_DEVICE - ) # 2D embeddings - labels = torch.LongTensor([0, 0, 1, 1, 2]) - loss = loss_func(embeddings, labels) - loss.backward() + for loss_class in [NTXentLoss, SupConLoss]: + loss_funcA = loss_class(temperature) + loss_funcB = loss_class(temperature, distance=LpDistance()) + for dtype in TEST_DTYPES: + for loss_func in [loss_funcA, loss_funcB]: + embedding_angles = [0, 20, 40, 60, 80] + embeddings = torch.tensor( + [c_f.angle_to_coord(a) for a in embedding_angles], + requires_grad=True, + dtype=dtype, + ).to( + TEST_DEVICE + ) # 2D embeddings + labels = torch.LongTensor([0, 0, 1, 1, 2]) + loss = loss_func(embeddings, labels) + loss.backward() diff --git a/tests/losses/test_proxy_anchor_loss.py b/tests/losses/test_proxy_anchor_loss.py index d36c611d..2af18366 100644 --- a/tests/losses/test_proxy_anchor_loss.py +++ b/tests/losses/test_proxy_anchor_loss.py @@ -5,6 +5,8 @@ # This code is copied directly from the official implementation # so that we can make sure our implementation returns the same result. # It's copied under the MIT license. 
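
Context for the updated proxy-anchor test below: the ```to_dtype``` change earlier in this diff (skipping the cast when autocast is enabled) is what makes losses usable under mixed precision. As a rough sketch of the intended usage, assuming a hypothetical embedder and synthetic data (none of these names come from the patch):

```python
import torch
from pytorch_metric_learning import losses

device = torch.device("cuda")  # autocast here is the CUDA variant, as in the test below
model = torch.nn.Linear(128, 64).to(device)  # hypothetical embedder
loss_fn = losses.ProxyAnchorLoss(num_classes=10, embedding_size=64).to(device)
# ProxyAnchorLoss has learnable proxies, so they go into the optimizer too
optimizer = torch.optim.SGD(
    list(model.parameters()) + list(loss_fn.parameters()), lr=0.01
)
scaler = torch.cuda.amp.GradScaler()

data = torch.randn(32, 128, device=device)  # synthetic batch
labels = torch.randint(0, 10, (32,), device=device)

optimizer.zero_grad()
with torch.cuda.amp.autocast():
    embeddings = model(data)            # float16 inside autocast
    loss = loss_fn(embeddings, labels)  # to_dtype now leaves dtypes alone
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```
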
+from contextlib import nullcontext + import torch import torch.nn as nn import torch.nn.functional as F @@ -85,34 +87,50 @@ def test_proxyanchor_loss(self): num_classes = 10 embedding_size = 2 margin = 0.5 - for dtype in TEST_DTYPES: - alpha = 1 if dtype == torch.float16 else 32 - loss_func = ProxyAnchorLoss( - num_classes, embedding_size, margin=margin, alpha=alpha - ).to(TEST_DEVICE) - original_loss_func = OriginalImplementationProxyAnchor( - num_classes, embedding_size, mrg=margin, alpha=alpha - ).to(TEST_DEVICE) - original_loss_func.proxies.data = original_loss_func.proxies.data.type( - dtype - ) - loss_func.proxies = original_loss_func.proxies - - embedding_angles = list(range(0, 180)) - embeddings = torch.tensor( - [c_f.angle_to_coord(a) for a in embedding_angles], - requires_grad=True, - dtype=dtype, - ).to( - TEST_DEVICE - ) # 2D embeddings - labels = torch.randint(low=0, high=5, size=(180,)).to(TEST_DEVICE) - - loss = loss_func(embeddings, labels) - loss.backward() - correct_loss = original_loss_func(embeddings, labels) - rtol = 1e-2 if dtype == torch.float16 else 1e-5 - self.assertTrue(torch.isclose(loss, correct_loss, rtol=rtol)) + + for use_autocast in [True, False]: + + if use_autocast: + cm = torch.cuda.amp.autocast() + else: + cm = nullcontext() + + for dtype in TEST_DTYPES: + alpha = 1 if dtype == torch.float16 else 32 + loss_func = ProxyAnchorLoss( + num_classes, embedding_size, margin=margin, alpha=alpha + ).to(TEST_DEVICE) + original_loss_func = OriginalImplementationProxyAnchor( + num_classes, embedding_size, mrg=margin, alpha=alpha + ).to(TEST_DEVICE) + + if not use_autocast: + original_loss_func.proxies.data = ( + original_loss_func.proxies.data.type(dtype) + ) + loss_func.proxies = original_loss_func.proxies + + embedding_angles = list(range(0, 180)) + embeddings = torch.tensor( + [c_f.angle_to_coord(a) for a in embedding_angles], + requires_grad=True, + dtype=torch.float32, + ).to( + TEST_DEVICE + ) # 2D embeddings + + if not use_autocast: + embeddings = embeddings.type(dtype) + labels = torch.randint(low=0, high=5, size=(180,)).to(TEST_DEVICE) + + with cm: + loss = loss_func(embeddings, labels) + + loss.backward() + correct_loss = original_loss_func(embeddings, labels) + rtol = 1e-2 if dtype == torch.float16 or use_autocast else 1e-5 + + self.assertTrue(torch.isclose(loss, correct_loss, rtol=rtol)) if __name__ == "__main__": diff --git a/tests/utils/test_calculate_accuracies.py b/tests/utils/test_calculate_accuracies.py index 2ee973e1..4da3342d 100644 --- a/tests/utils/test_calculate_accuracies.py +++ b/tests/utils/test_calculate_accuracies.py @@ -226,7 +226,7 @@ def fn2(x, y): device=TEST_DEVICE, ) - label_counts, num_k = accuracy_calculator.get_label_match_counts( + label_counts = accuracy_calculator.get_label_match_counts( query_labels, query_labels, comparison_fn, @@ -268,7 +268,7 @@ def custom_label_comparison_fn(x, y): ) for comparison_fn in [equality2D, custom_label_comparison_fn]: - label_counts, num_k = accuracy_calculator.get_label_match_counts( + label_counts = accuracy_calculator.get_label_match_counts( query_labels, query_labels, comparison_fn, @@ -338,7 +338,7 @@ def custom_label_comparison_fn(x, y): def test_get_lone_query_labels(self): query_labels = torch.tensor([0, 1, 2, 3, 4, 5, 6], device=TEST_DEVICE) reference_labels = torch.tensor([0, 0, 0, 1, 2, 2, 3, 4, 5], device=TEST_DEVICE) - label_counts, _ = accuracy_calculator.get_label_match_counts( + label_counts = accuracy_calculator.get_label_match_counts( query_labels, reference_labels, 
accuracy_calculator.EQUALITY, @@ -630,3 +630,11 @@ def label_comparison_fn(x, y): ) for k in correct: self.assertTrue(isclose(acc[k], correct[k])) + + +class TestCalculateAccuraciesValidK(unittest.TestCase): + def test_valid_k(self): + for k in [-1, 0, 1.5, "max"]: + self.assertRaises( + ValueError, lambda: accuracy_calculator.AccuracyCalculator(k=k) + ) diff --git a/tests/utils/test_calculate_accuracies_large_k.py b/tests/utils/test_calculate_accuracies_large_k.py new file mode 100644 index 00000000..72a73a3d --- /dev/null +++ b/tests/utils/test_calculate_accuracies_large_k.py @@ -0,0 +1,113 @@ +import unittest + +import numpy as np +import torch + +from pytorch_metric_learning.utils import accuracy_calculator, stat_utils + +### FROM https://gist.github.com/VChristlein/fd55016f8d1b38e95011a025cbff9ccc +### and https://github.com/KevinMusgrave/pytorch-metric-learning/issues/290 + + +class TestCalculateAccuraciesLargeK(unittest.TestCase): + def test_accuracy_calculator_large_k(self): + for ecfss in [False, True]: + for max_k in [None, "max_bin_count"]: + for num_embeddings in [1000, 2100]: + # make random features + encs = np.random.rand(num_embeddings, 5).astype(np.float32) + # and random labels of 100 classes + labels = np.zeros((num_embeddings // 100, 100), dtype=np.int32) + for i in range(10): + labels[i] = np.arange(100) + labels = labels.ravel() + + correct_p1, correct_map, correct_mapr = self.evaluate( + encs, labels, max_k, ecfss + ) + + # use Musgrave's library + if max_k is None: + k = len(encs) - 1 if ecfss else len(encs) + accs = [ + accuracy_calculator.AccuracyCalculator(), + accuracy_calculator.AccuracyCalculator(k=k), + ] + elif max_k == "max_bin_count": + accs = [ + accuracy_calculator.AccuracyCalculator(k="max_bin_count") + ] + + for acc in accs: + d = acc.get_accuracy( + encs, + encs, + labels, + labels, + ecfss, + include=( + "mean_average_precision", + "mean_average_precision_at_r", + "precision_at_1", + ), + ) + + self.assertTrue(np.isclose(correct_p1, d["precision_at_1"])) + self.assertTrue( + np.isclose(correct_map, d["mean_average_precision"]) + ) + self.assertTrue( + np.isclose(correct_mapr, d["mean_average_precision_at_r"]) + ) + + def evaluate(self, encs, labels, max_k=None, ecfss=False): + """ + evaluate encodings assuming using associated labels + parameters: + encs: TxD encoding matrix + labels: array/list of T labels + """ + + # let's use Musgrave's knn + torch_encs = torch.from_numpy(encs) + k = len(encs) - 1 if ecfss else len(encs) + all_indices, _ = stat_utils.get_knn(torch_encs, torch_encs, k, ecfss) + if max_k is None: + max_k = k + indices = all_indices + elif max_k == "max_bin_count": + max_k = int(max(np.bincount(labels))) - int(ecfss) + indices, _ = stat_utils.get_knn(torch_encs, torch_encs, max_k, ecfss) + + # let's use the most simple mAP implementation + # of course this can be computed much faster using cumsum, etc. + n_encs = len(encs) + mAP = [] + mAP_at_r = [] + correct = 0 + for r in range(n_encs): + precisions = [] + rel = 0 + # indices doesn't contain the query index itself anymore, so no correction w. 
-1 necessary + all_rel = np.count_nonzero(labels[all_indices[r]] == labels[r]) + prec_at_r = [] + for k in range(max_k): + if labels[indices[r, k]] == labels[r]: + rel += 1 + precisions.append(rel / float(k + 1)) + if k == 0: + correct += 1 + + # mAP@R + if k < all_rel: + prec_at_r.append(rel / float(k + 1)) + + avg_precision = np.mean(precisions) if len(precisions) > 0 else 0 + mAP.append(avg_precision) + # mAP@R + avg_prec_at_r = np.sum(prec_at_r) / all_rel if all_rel > 0 else 0 + mAP_at_r.append(avg_prec_at_r) + + mAP = np.mean(mAP) + mAP_at_r = np.mean(mAP_at_r) + return float(correct) / n_encs, mAP, mAP_at_r diff --git a/tests/utils/test_loss_and_miner_utils.py b/tests/utils/test_loss_and_miner_utils.py index 700ed3bc..3b2674cd 100644 --- a/tests/utils/test_loss_and_miner_utils.py +++ b/tests/utils/test_loss_and_miner_utils.py @@ -98,10 +98,29 @@ def test_convert_to_triplets(self): a2 = torch.LongTensor([0, 4, 5, 6]) triplets = lmu.convert_to_triplets((a1, p, a2, n), labels=torch.arange(7)) self.assertTrue( - triplets - == [torch.LongTensor([0]), torch.LongTensor([4]), torch.LongTensor([5])] + triplets == (torch.tensor([0]), torch.tensor([4]), torch.tensor([5])) ) + a1 = torch.LongTensor([0, 1, 0, 2]) + p = torch.LongTensor([5, 6, 7, 8]) + a2 = torch.LongTensor([0, 1, 2, 0]) + n = torch.LongTensor([9, 10, 11, 12]) + triplets = lmu.convert_to_triplets((a1, p, a2, n), labels=torch.arange(13)) + triplets = torch.stack(triplets, dim=1) + found_set = set() + for t in triplets: + found_set.add(tuple(t.cpu().numpy())) + correct_triplets = { + (0, 5, 9), + (0, 5, 12), + (0, 7, 9), + (0, 7, 12), + (1, 6, 10), + (2, 8, 11), + } + + self.assertTrue(found_set == correct_triplets) + def test_convert_to_weights(self): a = torch.LongTensor([0, 1, 2, 3]).to(TEST_DEVICE) p = torch.LongTensor([4, 4, 4, 4]).to(TEST_DEVICE)
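
Finally, since ```SupConLoss``` is the headline addition of this diff, here is a minimal usage sketch. The embeddings and labels are synthetic; the constructor and call signature match the tests above:

```python
import torch
from pytorch_metric_learning import losses

# CosineSimilarity and AvgNonZeroReducer by default, as documented in losses.md
loss_fn = losses.SupConLoss(temperature=0.1)

embeddings = torch.randn(16, 32, requires_grad=True)  # synthetic 32-dim embeddings
labels = torch.randint(0, 4, (16,))                   # 4 synthetic classes

loss = loss_fn(embeddings, labels)  # elements with no positive pairs contribute zero
loss.backward()
```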