From 2dc4bbafe36bec7cfb25c514dfe7d5841ac59cc3 Mon Sep 17 00:00:00 2001 From: huangjg Date: Sun, 24 Dec 2023 22:12:59 +0800 Subject: [PATCH 1/3] update the framework of scores --- torchcp/classification/predictors/split.py | 5 +--- torchcp/classification/scores/aps.py | 20 +++++++++------- torchcp/classification/scores/base.py | 10 +------- torchcp/classification/scores/margin.py | 20 +++++----------- torchcp/classification/scores/raps.py | 24 ++++++------------- torchcp/classification/scores/saps.py | 28 +++++++--------------- torchcp/classification/scores/thr.py | 24 ++++++++++++------- 7 files changed, 50 insertions(+), 81 deletions(-) diff --git a/torchcp/classification/predictors/split.py b/torchcp/classification/predictors/split.py index 2c63bbb..98b00f9 100644 --- a/torchcp/classification/predictors/split.py +++ b/torchcp/classification/predictors/split.py @@ -39,9 +39,6 @@ def calculate_threshold(self, logits, labels, alpha): logits = logits.to(self._device) labels = labels.to(self._device) scores = self.score_function(logits,labels) - # scores = logits.new_zeros(logits.shape[0]) - # for index, (x, y) in enumerate(zip(logits, labels)): - # scores[index] = self.score_function(x, y) self.q_hat = self._calculate_conformal_value(scores, alpha) def _calculate_conformal_value(self, scores, alpha): @@ -90,7 +87,7 @@ def predict_with_logits(self, logits, q_hat=None): :return: prediction sets """ - scores = self.score_function.predict(logits).to(self._device) + scores = self.score_function(logits).to(self._device) if q_hat is None: S = self._generate_prediction_set(scores, self.q_hat) else: diff --git a/torchcp/classification/scores/aps.py b/torchcp/classification/scores/aps.py index eef2ac1..e39d433 100644 --- a/torchcp/classification/scores/aps.py +++ b/torchcp/classification/scores/aps.py @@ -22,20 +22,21 @@ class APS(BaseScore): def __init__(self): super(APS, self).__init__() - def __call__(self, logits, y): + def __call__(self, logits, y=None): assert len(logits.shape) <= 2, "The dimension of logits must be less than 2." 
if len(logits) == 1: logits = logits.unsqueeze(0) probs = torch.softmax(logits, dim=-1) - indices, ordered, cumsum = self._sort_sum(probs) - return self.__compute_score(indices, y, cumsum, ordered) + if y is None: + return self._calculate_all_label(probs) + else: + return self._calculate_single_label(probs, y) - def predict(self, logits): - probs = torch.softmax(logits, dim=-1) - I, ordered, cumsum = self._sort_sum(probs) - U = torch.rand(probs.shape, device=logits.device) + def _calculate_all_label(self, probs): + indices, ordered, cumsum = self._sort_sum(probs) + U = torch.rand(probs.shape, device=probs.device) ordered_scores = cumsum - ordered * U - _, sorted_indices = torch.sort(I, descending=False, dim=-1) + _, sorted_indices = torch.sort(indices, descending=False, dim=-1) scores = ordered_scores.gather(dim=-1, index=sorted_indices) return scores @@ -47,7 +48,8 @@ def _sort_sum(self, probs): cumsum = torch.cumsum(ordered, dim=-1) return indices, ordered, cumsum - def __compute_score(self, indices, y, cumsum, ordered): + def _calculate_single_label(self, probs, y): + indices, ordered, cumsum = self._sort_sum(probs) U = torch.rand(indices.shape[0], device = indices.device) idx = torch.where(indices == y.view(-1, 1)) scores_first_rank = U * cumsum[idx] diff --git a/torchcp/classification/scores/base.py b/torchcp/classification/scores/base.py index c951575..753db8f 100644 --- a/torchcp/classification/scores/base.py +++ b/torchcp/classification/scores/base.py @@ -19,18 +19,10 @@ def __init__(self) -> None: pass @abstractmethod - def __call__(self, logits, y): + def __call__(self, logits, y=None): """Virtual method to compute scores for a data pair (x,y). :param logits: the logits for inputs. :y : the labels. """ raise NotImplementedError - - @abstractmethod - def predict(self, logits): - """Virtual method to compute scores of all labels for an input. - - :param logits: the logits for an input. - """ - raise NotImplementedError diff --git a/torchcp/classification/scores/margin.py b/torchcp/classification/scores/margin.py index ec2625b..3453934 100644 --- a/torchcp/classification/scores/margin.py +++ b/torchcp/classification/scores/margin.py @@ -6,10 +6,10 @@ # import torch -from torchcp.classification.scores.base import BaseScore +from torchcp.classification.scores.aps import APS -class Margin(BaseScore): +class Margin(APS): def __init__(self, ) -> None: """ @@ -17,24 +17,16 @@ def __init__(self, ) -> None: """ super().__init__() - - def __call__(self, logits, y): - assert len(logits.shape) <= 2, "The dimension of logits must be less than 2." - if len(logits) == 1: - logits = logits.unsqueeze(0) - probs = torch.softmax(logits, dim=-1) - - row_indices = torch.arange(probs.size(0), device = logits.device) + + def _calculate_single_label(self, probs, y): + row_indices = torch.arange(probs.size(0), device = probs.device) target_prob = probs[row_indices, y].clone() probs[row_indices, y] = -1 second_highest_prob = torch.max(probs, dim=-1).values return second_highest_prob - target_prob - def predict(self, logits): - assert len(logits.shape) <= 2, "The dimension of logits must be less than 2." 
- if len(logits) == 1: - logits = logits.unsqueeze(0) + def _calculate_all_label(self, logits): probs = torch.softmax(logits, dim=-1) temp_probs = probs.unsqueeze(1).repeat(1, probs.shape[1], 1) indices = torch.arange(probs.shape[1]).to(logits.device) diff --git a/torchcp/classification/scores/raps.py b/torchcp/classification/scores/raps.py index ab6427d..e2d5db2 100644 --- a/torchcp/classification/scores/raps.py +++ b/torchcp/classification/scores/raps.py @@ -35,30 +35,20 @@ def __init__(self, penalty, kreg=0): self.__penalty = penalty self.__kreg = kreg - def __call__(self, logits, y): - assert len(logits.shape) <= 2, "The dimension of logits must be less than 2." - if len(logits) == 1: - logits = logits.unsqueeze(0) - probs = torch.softmax(logits, dim=-1) - # sorting probabilities - indices, ordered, cumsum = self._sort_sum(probs) - return self.__compute_score(indices, y, cumsum, ordered) - - - def predict(self, logits): - probs = torch.softmax(logits, dim=-1) - I, ordered, cumsum = self._sort_sum(probs) - U = torch.rand(probs.shape, device = logits.device) - reg = torch.maximum(self.__penalty * (torch.arange(1, probs.shape[-1] + 1, device=logits.device) - self.__kreg),torch.tensor(0).to(logits.device)) + def _calculate_all_label(self, probs): + indices, ordered, cumsum = self._sort_sum(probs) + U = torch.rand(probs.shape, device = probs.device) + reg = torch.maximum(self.__penalty * (torch.arange(1, probs.shape[-1] + 1, device=probs.device) - self.__kreg),torch.tensor(0, device=probs.device)) ordered_scores = cumsum - ordered * U + reg - _, sorted_indices = torch.sort(I, descending=False, dim=-1) + _, sorted_indices = torch.sort(indices, descending=False, dim=-1) scores = ordered_scores.gather(dim=-1, index=sorted_indices) return scores - def __compute_score(self, indices, y, cumsum, ordered): + def _calculate_single_label(self, probs, y): + indices, ordered, cumsum = self._sort_sum(probs) U = torch.rand(indices.shape[0], device = indices.device) idx = torch.where(indices == y.view(-1, 1)) reg = torch.maximum(self.__penalty * (idx[1] + 1 - self.__kreg), torch.tensor(0).to(indices.device)) diff --git a/torchcp/classification/scores/saps.py b/torchcp/classification/scores/saps.py index 6760f8b..5ecb629 100644 --- a/torchcp/classification/scores/saps.py +++ b/torchcp/classification/scores/saps.py @@ -25,32 +25,20 @@ def __init__(self, weight): if weight <= 0: raise ValueError("The parameter 'weight' must be a positive value.") self.__weight = weight - - def __call__(self, logits, y): - assert len(logits.shape) <= 2, "The dimension of logits must be less than 2." 
- if len(logits) == 1: - logits = logits.unsqueeze(0) - probs = torch.softmax(logits, dim=-1) - # sorting probabilities + + + def _calculate_all_label(self, probs): indices, ordered, cumsum = self._sort_sum(probs) - return self.__compute_score(indices, y, cumsum, ordered) - - def predict(self, logits): - probs = torch.softmax(logits, dim=-1) - I, ordered, _ = self._sort_sum(probs) - if len(logits.shape) == 1: - ordered[1:] = self.__weight - else: - ordered[...,1:] = self.__weight + ordered[:,1:] = self.__weight cumsum = torch.cumsum(ordered, dim=-1) - U = torch.rand(probs.shape, device=logits.device) + U = torch.rand(probs.shape, device=probs.device) ordered_scores = cumsum - ordered * U - _, sorted_indices = torch.sort(I, descending=False, dim=-1) + _, sorted_indices = torch.sort(indices, descending=False, dim=-1) scores = ordered_scores.gather(dim=-1, index=sorted_indices) return scores - def __compute_score(self, indices, y, cumsum, ordered): - + def _calculate_single_label(self, probs, y): + indices, ordered, cumsum = self._sort_sum(probs) U = torch.rand(indices.shape[0], device = indices.device) idx = torch.where(indices == y.view(-1, 1)) scores_first_rank = U * cumsum[idx] diff --git a/torchcp/classification/scores/thr.py b/torchcp/classification/scores/thr.py index a276767..ad29d48 100644 --- a/torchcp/classification/scores/thr.py +++ b/torchcp/classification/scores/thr.py @@ -24,20 +24,28 @@ def __init__(self, score_type="softmax") -> None: if score_type == "Identity": self.transform = lambda x: x elif score_type == "softmax": - self.transform = lambda x: torch.softmax(x, dim=len(x.shape) - 1) + self.transform = lambda x: torch.softmax(x, dim=- 1) elif score_type == "log_softmax": - self.transform = lambda x: torch.log_softmax(x, dim=len(x.shape) - 1) + self.transform = lambda x: torch.log_softmax(x, dim=-1) elif score_type == "log": - self.transform = lambda x: torch.log(x, dim=len(x.shape) - 1) + self.transform = lambda x: torch.log(x, dim=-1) else: raise NotImplementedError - def __call__(self, logits, y): + def __call__(self, logits, y=None): assert len(logits.shape) <= 2, "The dimension of logits must be less than 2." if len(logits) == 1: logits = logits.unsqueeze(0) - return 1 - torch.softmax(logits, dim=-1)[torch.arange(y.shape[0], device = logits.device), y] - + temp_values = self.transform(logits) + if y is None: + return self.__calculate_all_label(temp_values) + else: + return self.__calculate_single_label(temp_values, y) + + def __calculate_single_label(self, temp_values, y): + return 1 - temp_values[torch.arange(y.shape[0], device = temp_values.device), y] + + def __calculate_all_label(self, temp_values): + return 1 - temp_values - def predict(self, logits): - return 1 - torch.softmax(logits, dim=-1) + From a9606facb5892d77db7c5ae25f81766c17b13296 Mon Sep 17 00:00:00 2001 From: huangjg Date: Sun, 24 Dec 2023 22:17:36 +0800 Subject: [PATCH 2/3] update Margin --- torchcp/classification/scores/aps.py | 2 +- torchcp/classification/scores/margin.py | 10 +++------- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/torchcp/classification/scores/aps.py b/torchcp/classification/scores/aps.py index e39d433..7e845bc 100644 --- a/torchcp/classification/scores/aps.py +++ b/torchcp/classification/scores/aps.py @@ -20,7 +20,7 @@ class APS(BaseScore): """ def __init__(self): - super(APS, self).__init__() + pass def __call__(self, logits, y=None): assert len(logits.shape) <= 2, "The dimension of logits must be less than 2." 
diff --git a/torchcp/classification/scores/margin.py b/torchcp/classification/scores/margin.py index 3453934..75ad484 100644 --- a/torchcp/classification/scores/margin.py +++ b/torchcp/classification/scores/margin.py @@ -12,10 +12,7 @@ class Margin(APS): def __init__(self, ) -> None: - """ - param score_type: either "softmax" "Identity", "log_softmax" or "log". Default: "softmax". A transformation for logits. - """ - super().__init__() + pass def _calculate_single_label(self, probs, y): @@ -26,10 +23,9 @@ def _calculate_single_label(self, probs, y): return second_highest_prob - target_prob - def _calculate_all_label(self, logits): - probs = torch.softmax(logits, dim=-1) + def _calculate_all_label(self, probs): temp_probs = probs.unsqueeze(1).repeat(1, probs.shape[1], 1) - indices = torch.arange(probs.shape[1]).to(logits.device) + indices = torch.arange(probs.shape[1]).to(probs.device) temp_probs[None, indices, indices] = torch.finfo(torch.float32).min scores = torch.max(temp_probs, dim=-1).values - probs return scores From df02ac3a183c022239a3d14b282e0f41409c97ef Mon Sep 17 00:00:00 2001 From: Hongxin Wei Date: Sun, 24 Dec 2023 22:34:45 +0800 Subject: [PATCH 3/3] update code --- .github/workflows/deploy.yml | 50 +++++++++---------- CONTRIBUTING.md | 6 ++- README.md | 31 +++++++----- docs/source/conf.py | 6 +-- docs/source/index.rst | 6 +-- examples/clip/clip.py | 12 +++-- examples/clip/model.py | 10 ++-- examples/clip/simple_tokenizer.py | 24 +++++---- examples/covariate_shift.py | 12 +++-- examples/imagenet_example_logits.py | 7 +-- torchcp/classification/predictors/base.py | 1 - .../classification/predictors/classwise.py | 2 +- torchcp/classification/predictors/cluster.py | 2 - torchcp/classification/predictors/split.py | 17 ++++--- torchcp/classification/predictors/weight.py | 2 +- torchcp/classification/scores/aps.py | 18 +++---- torchcp/classification/scores/base.py | 4 +- torchcp/classification/scores/margin.py | 14 ++---- torchcp/classification/scores/raps.py | 27 +++++----- torchcp/classification/scores/saps.py | 18 +++---- torchcp/classification/scores/thr.py | 16 +++--- torchcp/regression/predictors/cqr.py | 2 +- torchcp/regression/predictors/split.py | 2 +- 23 files changed, 146 insertions(+), 143 deletions(-) diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index 8623870..87f140c 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -2,7 +2,7 @@ name: Publish Python 🐍 distributions 📦 to PyPI on: -# automatically running github actions when push a tag + # automatically running github actions when push a tag push: tags: - '*' @@ -20,27 +20,27 @@ jobs: id-token: write contents: read steps: - - uses: actions/checkout@master - - name: Set up Python 3.10 - uses: actions/setup-python@v3 - with: - python-version: '3.10' - - name: Install pypa/setuptools - run: >- - python -m - pip install wheel - pip install readme_renderer[md] - - name: Build a binary wheel - run: >- - python setup.py sdist bdist_wheel -# - name: Publish distribution 📦 to TestPyPI -# uses: pypa/gh-action-pypi-publish@release/v1 -# with: -# user: __token__ -# password: ${{ secrets.jianguo_test_pypi_password }} -# repository_url: https://test.pypi.org/legacy/ - - name: Publish distribution 📦 to PyPI - uses: pypa/gh-action-pypi-publish@release/v1 - with: - user: __token__ - password: ${{ secrets.jianguo_pypi_password }} + - uses: actions/checkout@master + - name: Set up Python 3.10 + uses: actions/setup-python@v3 + with: + python-version: '3.10' + - name: Install 
pypa/setuptools + run: >- + python -m + pip install wheel + pip install readme_renderer[md] + - name: Build a binary wheel + run: >- + python setup.py sdist bdist_wheel + # - name: Publish distribution 📦 to TestPyPI + # uses: pypa/gh-action-pypi-publish@release/v1 + # with: + # user: __token__ + # password: ${{ secrets.jianguo_test_pypi_password }} + # repository_url: https://test.pypi.org/legacy/ + - name: Publish distribution 📦 to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + user: __token__ + password: ${{ secrets.jianguo_pypi_password }} diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3ac8759..f79e576 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -4,9 +4,11 @@ Thank you considering contributing to TorchCP! This document provides brief guidelines for potential contributors. -Please use pull requests for new features, bug fixes, new examples, etc. If you work on something with significant efforts, please mention it in early stage using issues. +Please use pull requests for new features, bug fixes, new examples, etc. If you work on something with significant +efforts, please mention it in early stage using issues. -We ask that you follow the `PEP8` coding style in your pull requests, [`flake8`](http://flake8.pycqa.org/) is used in continuous integration to enforce this. +We ask that you follow the `PEP8` coding style in your pull requests, [`flake8`](http://flake8.pycqa.org/) is used in +continuous integration to enforce this. --- diff --git a/README.md b/README.md index 12121a4..a636f32 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,15 @@ -TorchCP is a Python toolbox for conformal prediction research on deep learning models, using PyTorch. Specifically, this toolbox has implemented some representative methods (including posthoc and training methods) for -classification and regression tasks. We build the framework of TorchCP based on [`AdverTorch`](https://github.com/BorealisAI/advertorch/tree/master). This codebase is still under construction. Comments, issues, contributions, and collaborations are all welcomed! - - +TorchCP is a Python toolbox for conformal prediction research on deep learning models, using PyTorch. Specifically, this +toolbox has implemented some representative methods (including posthoc and training methods) for +classification and regression tasks. We build the framework of TorchCP based +on [`AdverTorch`](https://github.com/BorealisAI/advertorch/tree/master). This codebase is still under construction. +Comments, issues, contributions, and collaborations are all welcomed! 
# Overview + TorchCP has implemented the following methods: + ## Classification + | Year | Title | Venue | Code Link | |------|--------------------------------------------------------------------------------------------------------------------------------------------------|---------|-----------------------------------------------------------------------------------| | 2023 | [**Class-Conditional Conformal Prediction with Many Classes**](https://arxiv.org/abs/2306.09335) | NeurIPS | [Link](https://github.com/tiffanyding/class-conditional-conformal) | @@ -18,15 +22,15 @@ TorchCP has implemented the following methods: | 2013 | [**Applications of Class-Conditional Conformal Predictor in Multi-Class Classification**](https://ieeexplore.ieee.org/document/6784618) | ICMLA | | ## Regression + | Year | Title | Venue | Code Link | |------|------------------------------------------------------------------------------------------------------------------------------------------------|---------|------------------------------------------------------| | 2021 | [**Adaptive Conformal Inference Under Distribution Shift**](https://arxiv.org/abs/2106.00170) | NeurIPS | [Link](https://github.com/isgibbs/AdaptiveConformal) | | 2019 | [**Conformalized Quantile Regression**](https://proceedings.neurips.cc/paper_files/paper/2019/file/5103c3584b063c431bd1268e9b5e76fb-Paper.pdf) | NeurIPS | [Link](https://github.com/yromano/cqr) | | 2016 | [**Distribution-Free Predictive Inference For Regression**](https://arxiv.org/abs/1604.04173) | JASA | [Link](https://github.com/ryantibs/conformal) | - - ## TODO + TorchCP is still under active development. We will add the following features/items down the road: | Year | Title | Venue | Code Link | @@ -37,17 +41,16 @@ TorchCP is still under active development. We will add the following features/it | 2022 | [**Conformal Prediction Sets with Limited False Positives**](https://arxiv.org/abs/2202.07650) | ICML | [Link](https://github.com/ajfisch/conformal-fp) | | 2021 | [**Optimized conformal classification using gradient descent approximation**](https://arxiv.org/abs/2105.11255) | Arxiv | | - - - - ## Installation TorchCP is developed with Python 3.9 and PyTorch 2.0.1. To install TorchCP, simply run + ``` pip install torchcp ``` + To install from TestPyPI server, run + ``` pip install --index-url https://test.pypi.org/simple/ --no-deps torchcp ``` @@ -55,6 +58,7 @@ pip install --index-url https://test.pypi.org/simple/ --no-deps torchcp ## Examples Here, we provide a simple example for a classification task, with THR score and SplitPredictor. + ```python from torchcp.classification.scores import THR from torchcp.classification.predictors import SplitPredictor @@ -88,19 +92,21 @@ result_dict = predictor.evaluate(test_dataloader) print(result_dict["Coverage_rate"], result_dict["Average_size"]) ``` + You may find more tutorials in [`examples`](https://github.com/ml-stat-Sustech/TorchCP/tree/master/examples) folder. ## Documentation The documentation webpage is on readthedocs https://torchcp.readthedocs.io/en/latest/index.html. - ## License + This project is licensed under the LGPL. The terms and conditions can be found in the LICENSE and LICENSE.GPL files. ## Citation -We will release the technical report of TorchCP recently. If you find our repository useful for your research, please consider citing our paper: +We will release the technical report of TorchCP recently. 
If you find our repository useful for your research, please +consider citing our paper: ``` @article{huang2023conformal, @@ -110,6 +116,7 @@ We will release the technical report of TorchCP recently. If you find our reposi year={2023} } ``` + ## Contributors * [Hongxin Wei](https://hongxin001.github.io/) diff --git a/docs/source/conf.py b/docs/source/conf.py index 5b8ca82..5036df6 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -7,9 +7,11 @@ # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information import os import sys + sys.path.insert(0, os.path.abspath('../../')) from unittest.mock import Mock # noqa: F401, E402 + # from sphinx.ext.autodoc.importer import _MockObject as Mock Mock.Module = object sys.modules['torch'] = Mock() @@ -49,8 +51,6 @@ with open(os.path.join(os.path.abspath('../../'), 'torchcp/VERSION')) as f: version = f.read().strip() - - # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration @@ -78,7 +78,6 @@ # The master toctree document. master_doc = 'index' - # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output @@ -88,7 +87,6 @@ html_theme = 'sphinx_rtd_theme' html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] - # A list of files that should not be packed into the epub file. epub_exclude_files = ['search.html'] diff --git a/docs/source/index.rst b/docs/source/index.rst index 3a6e7cc..e526e74 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,7 +1,7 @@ .. TorchCP documentation master file, created by - sphinx-quickstart on Fri Dec 22 16:28:31 2023. - You can adapt this file completely to your liking, but it should at least - contain the root `toctree` directive. +sphinx-quickstart on Fri Dec 22 16:28:31 2023. +You can adapt this file completely to your liking, but it should at least +contain the root `toctree` directive. 
Welcome to TorchCP =================================== diff --git a/examples/clip/clip.py b/examples/clip/clip.py index f7a5da5..e025e61 100644 --- a/examples/clip/clip.py +++ b/examples/clip/clip.py @@ -15,15 +15,14 @@ try: from torchvision.transforms import InterpolationMode + BICUBIC = InterpolationMode.BICUBIC except ImportError: BICUBIC = Image.BICUBIC - if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): warnings.warn("PyTorch version 1.7.1 or higher is recommended") - __all__ = ["available_models", "load", "tokenize"] _tokenizer = _Tokenizer() @@ -57,7 +56,8 @@ def _download(url: str, root: str): warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: - with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: + with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, + unit_divisor=1024) as loop: while True: buffer = source.read(8192) if not buffer: @@ -91,7 +91,8 @@ def available_models() -> List[str]: return list(_MODELS.keys()) -def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): +def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", + jit: bool = False, download_root: str = None): """Load a CLIP model Parameters @@ -202,7 +203,8 @@ def patch_float(module): return model, _transform(model.input_resolution.item()) -def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: +def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[ + torch.IntTensor, torch.LongTensor]: """ Returns the tokenized representation of given input string(s) diff --git a/examples/clip/model.py b/examples/clip/model.py index 232b779..729b8f2 100644 --- a/examples/clip/model.py +++ b/examples/clip/model.py @@ -224,7 +224,9 @@ def forward(self, x: torch.Tensor): x = self.conv1(x) # shape = [*, width, grid, grid] x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] - x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = torch.cat( + [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), + x], dim=1) # shape = [*, grid ** 2 + 1, width] x = x + self.positional_embedding.to(x.dtype) x = self.ln_pre(x) @@ -401,12 +403,14 @@ def build_model(state_dict: dict): if vit: vision_width = state_dict["visual.conv1.weight"].shape[0] - vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_layers = len( + [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) image_resolution = vision_patch_size * grid_size else: - counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + counts: list = [len(set(k.split(".")[2] for k in 
state_dict if k.startswith(f"visual.layer{b}"))) for b in + [1, 2, 3, 4]] vision_layers = tuple(counts) vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) diff --git a/examples/clip/simple_tokenizer.py b/examples/clip/simple_tokenizer.py index 0a66286..9a7a59e 100644 --- a/examples/clip/simple_tokenizer.py +++ b/examples/clip/simple_tokenizer.py @@ -23,13 +23,13 @@ def bytes_to_unicode(): To avoid that, we want lookup tables between utf-8 bytes and unicode strings. And avoids mapping to whitespace/control characters the bpe code barfs on. """ - bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) cs = bs[:] n = 0 - for b in range(2**8): + for b in range(2 ** 8): if b not in bs: bs.append(b) - cs.append(2**8+n) + cs.append(2 ** 8 + n) n += 1 cs = [chr(n) for n in cs] return dict(zip(bs, cs)) @@ -64,10 +64,10 @@ def __init__(self, bpe_path: str = default_bpe()): self.byte_encoder = bytes_to_unicode() self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') - merges = merges[1:49152-256-2+1] + merges = merges[1:49152 - 256 - 2 + 1] merges = [tuple(merge.split()) for merge in merges] vocab = list(bytes_to_unicode().values()) - vocab = vocab + [v+'' for v in vocab] + vocab = vocab + [v + '' for v in vocab] for merge in merges: vocab.append(''.join(merge)) vocab.extend(['<|startoftext|>', '<|endoftext|>']) @@ -75,19 +75,21 @@ def __init__(self, bpe_path: str = default_bpe()): self.decoder = {v: k for k, v in self.encoder.items()} self.bpe_ranks = dict(zip(merges, range(len(merges)))) self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} - self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + self.pat = re.compile( + r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE) def bpe(self, token): if token in self.cache: return self.cache[token] - word = tuple(token[:-1]) + ( token[-1] + '',) + word = tuple(token[:-1]) + (token[-1] + '',) pairs = get_pairs(word) if not pairs: - return token+'' + return token + '' while True: - bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) if bigram not in self.bpe_ranks: break first, second = bigram @@ -102,8 +104,8 @@ def bpe(self, token): new_word.extend(word[i:]) break - if word[i] == first and i < len(word)-1 and word[i+1] == second: - new_word.append(first+second) + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) i += 2 else: new_word.append(word[i]) diff --git a/examples/covariate_shift.py b/examples/covariate_shift.py index 1ffe312..6ce746c 100644 --- a/examples/covariate_shift.py +++ b/examples/covariate_shift.py @@ -14,7 +14,7 @@ import torch.nn as nn from dataset import build_dataset -from torchcp.classification.predictors import WeightedPredictor,SplitPredictor +from torchcp.classification.predictors import WeightedPredictor, SplitPredictor from torchcp.classification.scores import THR from torchcp.utils import fix_randomness @@ -67,7 +67,8 @@ def forward(self, 
x_batch): score_function = THR() alpha = 0.1 - print(f"Experiment--Source Data : imagenet, Target Data : imagenetv2, Model : {model_name}, Score : THR, Predictor : Standard, Alpha : {alpha}") + print( + f"Experiment--Source Data : imagenet, Target Data : imagenetv2, Model : {model_name}, Score : THR, Predictor : Standard, Alpha : {alpha}") predictors = SplitPredictor(score_function, model) predictors.calibrate(cal_data_loader, alpha) @@ -78,7 +79,10 @@ def forward(self, x_batch): ################################## # Invalid prediction sets ################################## - print(f"Experiment--Source Data : imagenet, Target Data : imagenetv2, Model : {model_name}, Score : THR, Predictor : WeightedPredictor, Alpha : {alpha}") + print( + f"Experiment--Source Data : imagenet, Target Data : imagenetv2, Model : {model_name}, Score : THR, Predictor : WeightedPredictor, Alpha : {alpha}") + + class ImageEncoder(nn.Module): def __init__(self, model): super(ImageEncoder, self).__init__() @@ -89,9 +93,9 @@ def forward(self, x_bs): image_features /= image_features.norm(dim=-1, keepdim=True) return image_features + image_encoder = ImageEncoder(clip.eval().to(model_device)) - predictor = WeightedPredictor(score_function, model, image_encoder) # If you have prepared a domain classifier, you can pass it to WeightedPredictor. # Otherwise, you can use following codes, it will automatically train a domain classifier. diff --git a/examples/imagenet_example_logits.py b/examples/imagenet_example_logits.py index d475d66..ab92016 100644 --- a/examples/imagenet_example_logits.py +++ b/examples/imagenet_example_logits.py @@ -81,7 +81,8 @@ num_classes = 1000 alpha = args.alpha - print(f"Experiment--Data : ImageNet, Model : {model_name}, Score : {args.score}, Predictor : {args.predictor}, Alpha : {alpha}") + print( + f"Experiment--Data : ImageNet, Model : {model_name}, Score : {args.score}, Predictor : {args.predictor}, Alpha : {alpha}") if args.score == "THR": score_function = THR() elif args.score == "APS": @@ -94,7 +95,7 @@ score_function = Margin() else: raise NotImplementedError - + if args.predictor == "Standard": predictor = SplitPredictor(score_function, model=None) elif args.predictor == "ClassWise": @@ -109,7 +110,7 @@ # for index, ele in enumerate(test_logits): # prediction_set = predictor.predict_with_logits(ele) # prediction_sets.append(prediction_set) - + prediction_sets = predictor.predict_with_logits(test_logits) metrics = Metrics() diff --git a/torchcp/classification/predictors/base.py b/torchcp/classification/predictors/base.py index 558b9c9..ae954c0 100644 --- a/torchcp/classification/predictors/base.py +++ b/torchcp/classification/predictors/base.py @@ -64,4 +64,3 @@ def _generate_prediction_set(self, scores, q_hat): return torch.argwhere(scores < q_hat).reshape(-1).tolist() else: return [torch.argwhere(scores[i] < q_hat).reshape(-1).tolist() for i in range(scores.shape[0])] - diff --git a/torchcp/classification/predictors/classwise.py b/torchcp/classification/predictors/classwise.py index c574e3e..f2b9026 100644 --- a/torchcp/classification/predictors/classwise.py +++ b/torchcp/classification/predictors/classwise.py @@ -32,7 +32,7 @@ def calculate_threshold(self, logits, labels, alpha): x_cal_tmp = logits[labels == label] y_cal_tmp = labels[labels == label] scores = logits.new_zeros(x_cal_tmp.shape[0]) - + scores = self.score_function(x_cal_tmp, y_cal_tmp) self.q_hat[label] = self._calculate_conformal_value(scores, alpha) diff --git a/torchcp/classification/predictors/cluster.py 
b/torchcp/classification/predictors/cluster.py index 93c3465..df6b884 100644 --- a/torchcp/classification/predictors/cluster.py +++ b/torchcp/classification/predictors/cluster.py @@ -43,8 +43,6 @@ def calculate_threshold(self, logits, labels, alpha): labels = labels.to(self._device) self.num_classes = logits.shape[1] scores = self.score_function(logits, labels) - - alpha = torch.tensor(alpha, device=self._device) classes_statistics = torch.tensor([torch.sum(labels == k).item() for k in range(self.num_classes)], diff --git a/torchcp/classification/predictors/split.py b/torchcp/classification/predictors/split.py index 98b00f9..2ecd529 100644 --- a/torchcp/classification/predictors/split.py +++ b/torchcp/classification/predictors/split.py @@ -15,7 +15,6 @@ class SplitPredictor(BasePredictor): def __init__(self, score_function, model=None, temperature=1): super().__init__(score_function, model, temperature) - ############################# # The calibration process ############################ @@ -38,9 +37,9 @@ def calculate_threshold(self, logits, labels, alpha): raise ValueError("Significance level 'alpha' must be in (0,1).") logits = logits.to(self._device) labels = labels.to(self._device) - scores = self.score_function(logits,labels) + scores = self.score_function(logits, labels) self.q_hat = self._calculate_conformal_value(scores, alpha) - + def _calculate_conformal_value(self, scores, alpha): """ Calculate the 1-alpha quantile of scores. @@ -51,14 +50,16 @@ def _calculate_conformal_value(self, scores, alpha): :return: the threshold which is use to construct prediction sets. """ if len(scores) == 0: - warnings.warn("The number of scores is 0, which is a invalid scores. To avoid program crash, the threshold is set as torch.inf.") + warnings.warn( + "The number of scores is 0, which is a invalid scores. To avoid program crash, the threshold is set as torch.inf.") return torch.inf qunatile_value = math.ceil(scores.shape[0] + 1) * (1 - alpha) / scores.shape[0] - + if qunatile_value > 1: - warnings.warn("The value of quantile exceeds 1. It should be a value in (0,1). To avoid program crash, the threshold is set as torch.inf.") + warnings.warn( + "The value of quantile exceeds 1. It should be a value in (0,1). 
To avoid program crash, the threshold is set as torch.inf.") return torch.inf - + return torch.quantile(scores, qunatile_value).to(self._device) ############################# @@ -74,7 +75,7 @@ def predict(self, x_batch): if self._model != None: x_batch = self._model(x_batch.to(self._device)).float() x_batch = self._logits_transformation(x_batch).detach() - sets = self.predict_with_logits(x_batch) + sets = self.predict_with_logits(x_batch) return sets def predict_with_logits(self, logits, q_hat=None): diff --git a/torchcp/classification/predictors/weight.py b/torchcp/classification/predictors/weight.py index a4289cf..7a3c0f6 100644 --- a/torchcp/classification/predictors/weight.py +++ b/torchcp/classification/predictors/weight.py @@ -26,7 +26,7 @@ def __init__(self, score_function, model, image_encoder, domain_classifier=None, self.alpha = None # Domain Classifier self.domain_classifier = domain_classifier - + def calibrate(self, cal_dataloader, alpha): logits_list = [] labels_list = [] diff --git a/torchcp/classification/scores/aps.py b/torchcp/classification/scores/aps.py index 7e845bc..5f8b866 100644 --- a/torchcp/classification/scores/aps.py +++ b/torchcp/classification/scores/aps.py @@ -22,15 +22,15 @@ class APS(BaseScore): def __init__(self): pass - def __call__(self, logits, y=None): + def __call__(self, logits, label=None): assert len(logits.shape) <= 2, "The dimension of logits must be less than 2." if len(logits) == 1: logits = logits.unsqueeze(0) probs = torch.softmax(logits, dim=-1) - if y is None: + if label is None: return self._calculate_all_label(probs) else: - return self._calculate_single_label(probs, y) + return self._calculate_single_label(probs, label) def _calculate_all_label(self, probs): indices, ordered, cumsum = self._sort_sum(probs) @@ -48,13 +48,11 @@ def _sort_sum(self, probs): cumsum = torch.cumsum(ordered, dim=-1) return indices, ordered, cumsum - def _calculate_single_label(self, probs, y): + def _calculate_single_label(self, probs, label): indices, ordered, cumsum = self._sort_sum(probs) - U = torch.rand(indices.shape[0], device = indices.device) - idx = torch.where(indices == y.view(-1, 1)) - scores_first_rank = U * cumsum[idx] + U = torch.rand(indices.shape[0], device=probs.device) + idx = torch.where(indices == label.view(-1, 1)) + scores_first_rank = U * cumsum[idx] idx_minus_one = (idx[0], idx[1] - 1) - scores_usual = U * ordered[idx] + cumsum[idx_minus_one] + scores_usual = U * ordered[idx] + cumsum[idx_minus_one] return torch.where(idx[1] == 0, scores_first_rank, scores_usual) - - diff --git a/torchcp/classification/scores/base.py b/torchcp/classification/scores/base.py index 753db8f..45170b0 100644 --- a/torchcp/classification/scores/base.py +++ b/torchcp/classification/scores/base.py @@ -19,10 +19,10 @@ def __init__(self) -> None: pass @abstractmethod - def __call__(self, logits, y=None): + def __call__(self, logits, labels=None): """Virtual method to compute scores for a data pair (x,y). :param logits: the logits for inputs. - :y : the labels. + :param labels : the labels. 
""" raise NotImplementedError diff --git a/torchcp/classification/scores/margin.py b/torchcp/classification/scores/margin.py index 75ad484..229d504 100644 --- a/torchcp/classification/scores/margin.py +++ b/torchcp/classification/scores/margin.py @@ -14,14 +14,12 @@ class Margin(APS): def __init__(self, ) -> None: pass - - def _calculate_single_label(self, probs, y): - row_indices = torch.arange(probs.size(0), device = probs.device) - target_prob = probs[row_indices, y].clone() - probs[row_indices, y] = -1 + def _calculate_single_label(self, probs, label): + row_indices = torch.arange(probs.size(0), device=probs.device) + target_prob = probs[row_indices, label].clone() + probs[row_indices, label] = -1 second_highest_prob = torch.max(probs, dim=-1).values return second_highest_prob - target_prob - def _calculate_all_label(self, probs): temp_probs = probs.unsqueeze(1).repeat(1, probs.shape[1], 1) @@ -29,7 +27,3 @@ def _calculate_all_label(self, probs): temp_probs[None, indices, indices] = torch.finfo(torch.float32).min scores = torch.max(temp_probs, dim=-1).values - probs return scores - - - - diff --git a/torchcp/classification/scores/raps.py b/torchcp/classification/scores/raps.py index e2d5db2..ee14501 100644 --- a/torchcp/classification/scores/raps.py +++ b/torchcp/classification/scores/raps.py @@ -35,25 +35,22 @@ def __init__(self, penalty, kreg=0): self.__penalty = penalty self.__kreg = kreg - def _calculate_all_label(self, probs): indices, ordered, cumsum = self._sort_sum(probs) - U = torch.rand(probs.shape, device = probs.device) - reg = torch.maximum(self.__penalty * (torch.arange(1, probs.shape[-1] + 1, device=probs.device) - self.__kreg),torch.tensor(0, device=probs.device)) + U = torch.rand(probs.shape, device=probs.device) + reg = torch.maximum(self.__penalty * (torch.arange(1, probs.shape[-1] + 1, device=probs.device) - self.__kreg), + torch.tensor(0, device=probs.device)) ordered_scores = cumsum - ordered * U + reg _, sorted_indices = torch.sort(indices, descending=False, dim=-1) scores = ordered_scores.gather(dim=-1, index=sorted_indices) return scores - - - - def _calculate_single_label(self, probs, y): - indices, ordered, cumsum = self._sort_sum(probs) - U = torch.rand(indices.shape[0], device = indices.device) - idx = torch.where(indices == y.view(-1, 1)) - reg = torch.maximum(self.__penalty * (idx[1] + 1 - self.__kreg), torch.tensor(0).to(indices.device)) - scores_first_rank = U * ordered[idx] + reg - idx_minus_one = (idx[0], idx[1] - 1) - scores_usual = U * ordered[idx] + cumsum[idx_minus_one] + reg - return torch.where(idx[1] == 0, scores_first_rank, scores_usual) + def _calculate_single_label(self, probs, label): + indices, ordered, cumsum = self._sort_sum(probs) + U = torch.rand(indices.shape[0], device=probs.device) + idx = torch.where(indices == label.view(-1, 1)) + reg = torch.maximum(self.__penalty * (idx[1] + 1 - self.__kreg), torch.tensor(0).to(probs.device)) + scores_first_rank = U * ordered[idx] + reg + idx_minus_one = (idx[0], idx[1] - 1) + scores_usual = U * ordered[idx] + cumsum[idx_minus_one] + reg + return torch.where(idx[1] == 0, scores_first_rank, scores_usual) diff --git a/torchcp/classification/scores/saps.py b/torchcp/classification/scores/saps.py index 5ecb629..5983fbf 100644 --- a/torchcp/classification/scores/saps.py +++ b/torchcp/classification/scores/saps.py @@ -25,23 +25,21 @@ def __init__(self, weight): if weight <= 0: raise ValueError("The parameter 'weight' must be a positive value.") self.__weight = weight - - + def 
_calculate_all_label(self, probs): indices, ordered, cumsum = self._sort_sum(probs) - ordered[:,1:] = self.__weight + ordered[:, 1:] = self.__weight cumsum = torch.cumsum(ordered, dim=-1) U = torch.rand(probs.shape, device=probs.device) ordered_scores = cumsum - ordered * U _, sorted_indices = torch.sort(indices, descending=False, dim=-1) scores = ordered_scores.gather(dim=-1, index=sorted_indices) return scores - - def _calculate_single_label(self, probs, y): + + def _calculate_single_label(self, probs, label): indices, ordered, cumsum = self._sort_sum(probs) - U = torch.rand(indices.shape[0], device = indices.device) - idx = torch.where(indices == y.view(-1, 1)) - scores_first_rank = U * cumsum[idx] - scores_usual = self.__weight * (idx[1] - U) + ordered[:,0] + U = torch.rand(indices.shape[0], device=probs.device) + idx = torch.where(indices == label.view(-1, 1)) + scores_first_rank = U * cumsum[idx] + scores_usual = self.__weight * (idx[1] - U) + ordered[:, 0] return torch.where(idx[1] == 0, scores_first_rank, scores_usual) - diff --git a/torchcp/classification/scores/thr.py b/torchcp/classification/scores/thr.py index ad29d48..ae4a87f 100644 --- a/torchcp/classification/scores/thr.py +++ b/torchcp/classification/scores/thr.py @@ -32,20 +32,18 @@ def __init__(self, score_type="softmax") -> None: else: raise NotImplementedError - def __call__(self, logits, y=None): + def __call__(self, logits, label=None): assert len(logits.shape) <= 2, "The dimension of logits must be less than 2." if len(logits) == 1: logits = logits.unsqueeze(0) temp_values = self.transform(logits) - if y is None: + if label is None: return self.__calculate_all_label(temp_values) else: - return self.__calculate_single_label(temp_values, y) - - def __calculate_single_label(self, temp_values, y): - return 1 - temp_values[torch.arange(y.shape[0], device = temp_values.device), y] - + return self.__calculate_single_label(temp_values, label) + + def __calculate_single_label(self, temp_values, label): + return 1 - temp_values[torch.arange(label.shape[0], device=temp_values.device), label] + def __calculate_all_label(self, temp_values): return 1 - temp_values - - diff --git a/torchcp/regression/predictors/cqr.py b/torchcp/regression/predictors/cqr.py index 468926b..c7c7084 100644 --- a/torchcp/regression/predictors/cqr.py +++ b/torchcp/regression/predictors/cqr.py @@ -23,7 +23,7 @@ def __init__(self, model): super().__init__(model) def calculate_threshold(self, predicts, y_truth, alpha): - if alpha>=1 or alpha<=0: + if alpha >= 1 or alpha <= 0: raise ValueError("Significance level 'alpha' must be in (0,1).") self.scores = torch.maximum(predicts[:, 0] - y_truth, y_truth - predicts[:, 1]) quantile = math.ceil((self.scores.shape[0] + 1) * (1 - alpha)) / self.scores.shape[0] diff --git a/torchcp/regression/predictors/split.py b/torchcp/regression/predictors/split.py index 64b0e10..94b7a11 100644 --- a/torchcp/regression/predictors/split.py +++ b/torchcp/regression/predictors/split.py @@ -42,7 +42,7 @@ def calibrate(self, cal_dataloader, alpha): self.calculate_threshold(predicts, y_truth, alpha) def calculate_threshold(self, predicts, y_truth, alpha): - if alpha>=1 or alpha<=0: + if alpha >= 1 or alpha <= 0: raise ValueError("Significance level 'alpha' must be in (0,1).") self.scores = torch.abs(predicts.reshape(-1) - y_truth) quantile = math.ceil((self.scores.shape[0] + 1) * (1 - alpha)) / self.scores.shape[0]
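
Note (not part of the patches above): the series replaces the separate `__call__(logits, y)` / `predict(logits)` pair on every score class with a single `__call__(logits, label=None)` entry point, which is why `SplitPredictor.predict_with_logits` now calls the score function directly. The sketch below illustrates how a caller is expected to use that unified interface. It reuses only the `THR` import path shown in the README example touched by PATCH 3/3; the toy logits and labels are made up for illustration, and the same calling pattern applies to `APS`, `RAPS`, `SAPS`, and `Margin` per these patches.

```python
import torch

from torchcp.classification.scores import THR

torch.manual_seed(0)

# Toy logits for a batch of 4 examples over 3 classes, plus their labels.
logits = torch.randn(4, 3)
labels = torch.tensor([0, 2, 1, 1])

score_function = THR()

# Calibration path: one non-conformity score per (logits, label) pair,
# as used by SplitPredictor.calculate_threshold.
cal_scores = score_function(logits, labels)   # shape: (4,)

# Prediction path: scores for every candidate label of every example,
# as used by SplitPredictor.predict_with_logits.
all_scores = score_function(logits)           # shape: (4, 3)

print(cal_scores.shape, all_scores.shape)
```

Folding the old `predict` method into `__call__` keeps one code path per score class: the label-conditional branch is selected by whether `label` is passed, so predictors no longer need to know which method to call.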