From 99bebb93d9be11c2f3470bbfc429c86b58eb43c3 Mon Sep 17 00:00:00 2001 From: huangjg Date: Mon, 25 Dec 2023 15:06:18 +0800 Subject: [PATCH] update confTr and docs --- docs/source/installation.rst | 2 +- docs/source/torchcp.classification.rst | 11 ++ examples/conformal_training.py | 153 ++++++++++------------ examples/imagenet_example.py | 28 ++-- tests/test_calssification_logits.py | 100 ++++++++++++++ torchcp/classification/loss/conftr.py | 69 +++++----- torchcp/classification/predictors/base.py | 2 + 7 files changed, 221 insertions(+), 144 deletions(-) create mode 100644 tests/test_calssification_logits.py diff --git a/docs/source/installation.rst b/docs/source/installation.rst index 586374d..80c4157 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -9,7 +9,7 @@ We developed TorchCP under Python 3.9 and PyTorch 2.0.1. To install TorchCP, sim .. code-block:: bash - pip install --index-url https://test.pypi.org/simple/ --no-deps torchcp + pip install torchcp or clone the repo and run diff --git a/docs/source/torchcp.classification.rst b/docs/source/torchcp.classification.rst index e06fd95..0b33224 100644 --- a/docs/source/torchcp.classification.rst +++ b/docs/source/torchcp.classification.rst @@ -26,6 +26,14 @@ Predictors ClusterPredictor WeightedPredictor +.. automodule:: torchcp.classification.loss +Predictors +------- + +.. autosummary:: + :nosignatures: + + ConfTr Detailed description -------------------- @@ -57,4 +65,7 @@ Detailed description :members: .. autoclass:: WeightedPredictor + :members: + +.. autoclass:: ConfTr :members: \ No newline at end of file diff --git a/examples/conformal_training.py b/examples/conformal_training.py index 8b4e245..99c367d 100644 --- a/examples/conformal_training.py +++ b/examples/conformal_training.py @@ -17,6 +17,7 @@ import argparse +import itertools import torch import torch.nn as nn @@ -29,27 +30,9 @@ from torchcp.classification.scores import THR, APS, SAPS, RAPS from torchcp.utils import fix_randomness -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Covariate shift') - parser.add_argument('--seed', default=0, type=int) - parser.add_argument('--predictor', default="Standard", help="Standard") - parser.add_argument('--score', default="THR", help="THR") - parser.add_argument('--loss', default="CE", help="CE | ConfTr") - args = parser.parse_args() - res = {'Coverage_rate': 0, 'Average_size': 0} - num_trials = 1 - for seed in range(num_trials): - fix_randomness(seed=seed) - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - ################################## - # Invalid prediction sets - ################################## - train_dataset = build_dataset("mnist") - train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=512, shuffle=True, pin_memory=True) - class Net(nn.Module): +class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.fc1 = nn.Linear(28 * 28, 500) @@ -60,76 +43,74 @@ def forward(self, x): x = F.relu(self.fc1(x)) x = self.fc2(x) return x + +def train(model, device, train_loader, optimizer, epoch): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = criterion(output, target) + loss.backward() + optimizer.step() + - - model = Net().to(device) - - if args.loss == "CE": +if __name__ == '__main__': + alpha = 0.01 + num_trials = 5 + + result = {} + for loss in ["CE", "ConfTr"]: + print(f"############################## {loss} #########################") + result[loss] = {} + if loss == "CE": criterion = nn.CrossEntropyLoss() - elif args.loss == "ConfTr": + elif loss == "ConfTr": predictor = SplitPredictor(score_function=THR(score_type="log_softmax")) criterion = ConfTr(weights=0.01, - predictor=predictor, - alpha=0.05, - fraction=0.5, - loss_types="valid", - base_loss_fn=nn.CrossEntropyLoss()) + predictor=predictor, + alpha=0.05, + fraction=0.5, + loss_types="valid", + base_loss_fn=nn.CrossEntropyLoss()) else: raise NotImplementedError - - optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5) - - - def train(model, device, train_loader, optimizer, epoch): - model.train() - for batch_idx, (data, target) in enumerate(train_loader): - data, target = data.to(device), target.to(device) - optimizer.zero_grad() - output = model(data) - loss = criterion(output, target) - loss.backward() - optimizer.step() - if batch_idx % 10 == 0: - print( - f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}') - - - checkpoint_path = f'.cache/conformal_training_model_checkpoint_{args.loss}_seed={seed}.pth' - # if os.path.exists(checkpoint_path): - # checkpoint = torch.load(checkpoint_path) - # model.load_state_dict(checkpoint['model_state_dict']) - # else: - for epoch in range(1, 10): - train(model, device, train_data_loader, optimizer, epoch) - - torch.save({'model_state_dict': model.state_dict(), }, checkpoint_path) - - test_dataset = build_dataset("mnist", mode='test') - cal_dataset, test_dataset = torch.utils.data.random_split(test_dataset, [5000, 5000]) - cal_data_loader = torch.utils.data.DataLoader(cal_dataset, batch_size=1600, shuffle=False, pin_memory=True) - test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1600, shuffle=False, pin_memory=True) - - if args.score == "THR": - score_function = THR() - elif args.score == "APS": - score_function = APS() - elif args.score == "RAPS": - score_function = RAPS(args.penalty, args.kreg) - elif args.score == "SAPS": - score_function = SAPS(weight=args.weight) - - alpha = 0.01 - if args.predictor == "Standard": - predictor = SplitPredictor(score_function, model) - elif args.predictor == "ClassWise": - predictor = ClassWisePredictor(score_function, model) - elif args.predictor == "Cluster": - predictor = ClusterPredictor(score_function, model, args.seed) - predictor.calibrate(cal_data_loader, alpha) - - # test examples - tmp_res = predictor.evaluate(test_data_loader) - res['Coverage_rate'] += tmp_res['Coverage_rate'] / num_trials - res['Average_size'] += tmp_res['Average_size'] / num_trials - - print(res) + for seed in range(num_trials): + fix_randomness(seed=seed) + ################################## + # Training a pyotrch model + ################################## + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + train_dataset = build_dataset("mnist") + train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=512, shuffle=True, pin_memory=True) + test_dataset = build_dataset("mnist", mode='test') + cal_dataset, test_dataset = torch.utils.data.random_split(test_dataset, [5000, 5000]) + cal_data_loader = torch.utils.data.DataLoader(cal_dataset, batch_size=1600, shuffle=False, pin_memory=True) + test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1600, shuffle=False, pin_memory=True) + + model = Net().to(device) + optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5) + for epoch in range(1, 10): + train(model, device, train_data_loader, optimizer, epoch) + + for score in ["THR", "APS", "RAPS", "SAPS"]: + if score == "THR": + score_function = THR() + elif score == "APS": + score_function = APS() + elif score == "RAPS": + score_function = RAPS(1, 0) + elif score == "SAPS": + score_function = SAPS(weight=0.2) + if score not in result[loss]: + result[loss][score] = {} + result[loss][score]['Coverage_rate'] = 0 + result[loss][score]['Average_size'] = 0 + predictor = SplitPredictor(score_function, model) + predictor.calibrate(cal_data_loader, alpha) + tmp_res = predictor.evaluate(test_data_loader) + result[loss][score]['Coverage_rate'] += tmp_res['Coverage_rate'] / num_trials + result[loss][score]['Average_size'] += tmp_res['Average_size'] / num_trials + + for score in ["THR", "APS", "RAPS", "SAPS"]: + print(f"Score: {score}. Result is {result[loss][score]}") diff --git a/examples/imagenet_example.py b/examples/imagenet_example.py index 1e911b0..1652933 100644 --- a/examples/imagenet_example.py +++ b/examples/imagenet_example.py @@ -34,8 +34,10 @@ fix_randomness(seed=args.seed) + ####################################### + # Loading ImageNet dataset and a pytorch model + ####################################### model_name = 'ResNet101' - # load model model = torchvision.models.resnet101(weights="IMAGENET1K_V1", progress=True) model_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model.to(model_device) @@ -55,6 +57,10 @@ cal_data_loader = torch.utils.data.DataLoader(cal_dataset, batch_size=1024, shuffle=False, pin_memory=True) test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1024, shuffle=False, pin_memory=True) + + ####################################### + # A standard process of conformal prediction + ####################################### alpha = args.alpha print( f"Experiment--Data : ImageNet, Model : {model_name}, Score : {args.score}, Predictor : {args.predictor}, Alpha : {alpha}") @@ -80,22 +86,4 @@ raise NotImplementedError print(f"The size of calibration set is {len(cal_dataset)}.") predictor.calibrate(cal_data_loader, alpha) - # predictor.evaluate(test_data_loader) - - # test examples - print("Testing examples...") - prediction_sets = [] - labels_list = [] - with torch.no_grad(): - for examples in tqdm(test_data_loader): - tmp_x, tmp_label = examples[0], examples[1] - prediction_sets_batch = predictor.predict(tmp_x) - prediction_sets.extend(prediction_sets_batch) - labels_list.append(tmp_label) - test_labels = torch.cat(labels_list) - - metrics = Metrics() - print("Etestuating prediction sets...") - print(f"Coverage_rate: {metrics('coverage_rate')(prediction_sets, test_labels)}.") - print(f"Average_size: {metrics('average_size')(prediction_sets, test_labels)}.") - print(f"CovGap: {metrics('CovGap')(prediction_sets, test_labels, alpha, num_classes)}.") + predictor.evaluate(test_data_loader) diff --git a/tests/test_calssification_logits.py b/tests/test_calssification_logits.py new file mode 100644 index 0000000..6648e13 --- /dev/null +++ b/tests/test_calssification_logits.py @@ -0,0 +1,100 @@ +# Copyright (c) 2023-present, SUSTech-ML. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + + +import argparse +import os +import pickle + +import torch +import torchvision +import torchvision.datasets as dset +import torchvision.transforms as trn +from tqdm import tqdm + +from torchcp.classification.predictors import SplitPredictor, ClusterPredictor, ClassWisePredictor +from torchcp.classification.scores import THR, APS, SAPS, RAPS, Margin +from torchcp.classification.utils.metrics import Metrics +from torchcp.utils import fix_randomness + + + + +def test_imagenet(): + ####################################### + # Loading ImageNet dataset and a pytorch model + ####################################### + fix_randomness(seed=0) + model_name = 'ResNet101' + fname = ".cache/" + model_name + ".pkl" + if os.path.exists(fname): + with open(fname, 'rb') as handle: + dataset = pickle.load(handle) + + else: + # load dataset + transform = trn.Compose([trn.Resize(256), + trn.CenterCrop(224), + trn.ToTensor(), + trn.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + usr_dir = os.path.expanduser('~') + data_dir = os.path.join(usr_dir, "data") + dataset = dset.ImageFolder(data_dir + "/imagenet/val", + transform) + data_loader = torch.utils.data.DataLoader(dataset, batch_size=320, shuffle=False, pin_memory=True) + + # load model + model = torchvision.models.resnet101(weights="IMAGENET1K_V1", progress=True) + + logits_list = [] + labels_list = [] + with torch.no_grad(): + for examples in tqdm(data_loader): + tmp_x, tmp_label = examples[0], examples[1] + tmp_logits = model(tmp_x) + logits_list.append(tmp_logits) + labels_list.append(tmp_label) + logits = torch.cat(logits_list) + labels = torch.cat(labels_list) + dataset = torch.utils.data.TensorDataset(logits, labels.long()) + with open(fname, 'wb') as handle: + pickle.dump(dataset, handle, protocol=pickle.HIGHEST_PROTOCOL) + + cal_data, val_data = torch.utils.data.random_split(dataset, [25000, 25000]) + cal_logits = torch.stack([sample[0] for sample in cal_data]) + cal_labels = torch.stack([sample[1] for sample in cal_data]) + + test_logits = torch.stack([sample[0] for sample in val_data]) + test_labels = torch.stack([sample[1] for sample in val_data]) + + num_classes = 1000 + + ####################################### + # A standard process of conformal prediction + ####################################### + alpha = 0.1 + predictors = [SplitPredictor, ClassWisePredictor, ClusterPredictor] + score_functions = [THR(), APS(), RAPS(1, 0), SAPS(0.2), Margin()] + for score in score_functions: + for class_predictor in predictors: + predictor = class_predictor(score) + predictor.calculate_threshold(cal_logits, cal_labels, alpha) + print(f"Experiment--Data : ImageNet, Model : {model_name}, Score : {score.__class__.__name__}, Predictor : {predictor.__class__.__name__}, Alpha : {alpha}") + # print("Testing examples...") + # prediction_sets = [] + # for index, ele in enumerate(test_logits): + # prediction_set = predictor.predict_with_logits(ele) + # prediction_sets.append(prediction_set) + prediction_sets = predictor.predict_with_logits(test_logits) + + metrics = Metrics() + print("Evaluating prediction sets...") + print(f"Coverage_rate: {metrics('coverage_rate')(prediction_sets, test_labels)}.") + print(f"Average_size: {metrics('average_size')(prediction_sets, test_labels)}.") + print(f"CovGap: {metrics('CovGap')(prediction_sets, test_labels, alpha, num_classes)}.") diff --git a/torchcp/classification/loss/conftr.py b/torchcp/classification/loss/conftr.py index 9978302..e4d8fb6 100644 --- a/torchcp/classification/loss/conftr.py +++ b/torchcp/classification/loss/conftr.py @@ -12,48 +12,48 @@ class ConfTr(nn.Module): - def __init__(self, weights, predictor, alpha, fraction, loss_types="valid", target_size=1, + """ + Conformal Training (Stutz et al., 2021). + Paper: https://arxiv.org/abs/2110.09192. + + + :param weights: the weight of each loss function + :param predictor: the CP predictors + :param alpha: the significance level for each training batch + :param fraction: the fraction of the calibration set in each training batch + :param loss_types: the selected (multi-selected) loss functions, which can be "valid", "classification", "probs", "coverage". + :param target_size: Optional: 0 | 1. + :param loss_transform: a transform for loss + :param base_loss_fn: a base loss function. For example, cross entropy in classification. + """ + def __init__(self, weight, predictor, alpha, fraction, loss_type="valid", target_size=1, loss_transform="square", base_loss_fn=None): - """ - :param weights: the weight of each loss function - :param predictor: the CP predictors - :param alpha: the significance level for each training batch - :param fraction: the fraction of the calibration set in each training batch - :param loss_types: the selected (multi-selected) loss functions, which can be "valid", "classification", "probs", "coverage". - :param target_size: - :param loss_transform: a transform for loss - :param base_loss_fn: a base loss function, such as cross entropy for classification - """ + super(ConfTr, self).__init__() - self.weight = weights + assert weight>0, "weight must be greater than 0." + assert (fraction > 0 and fraction<1), "fraction should be a value in (0,1)." + assert loss_type in ["valid", "classification", "probs", "coverage"], 'loss_type should be a value in ["valid", "classification", "probs", "coverage"].' + assert target_size==0 or target_size ==1, "target_size should be 0 or 1." + assert loss_transform in ["square", "abs", "log"], 'loss_transform should be a value in ["square", "abs", "log"].' + self.weight = weight self.predictor = predictor self.alpha = alpha self.fraction = fraction - self.base_loss_fn = base_loss_fn - + self.loss_type = loss_type self.target_size = target_size + self.base_loss_fn = base_loss_fn + if loss_transform == "square": self.transform = torch.square elif loss_transform == "abs": self.transform = torch.abs elif loss_transform == "log": self.transform = torch.log - else: - raise NotImplementedError self.loss_functions_dict = {"valid": self.__compute_hinge_size_loss, "probs": self.__compute_probabilistic_size_loss, "coverage": self.__compute_coverage_loss, - "classification": self.__compute_classification_loss} - - if type(loss_types) == set: - if type(weights) != set: - raise TypeError("weights must be a set.") - elif type(loss_types) == str: - if type(weights) != float and type(weights) != int: - raise TypeError("weights must be a float or a int.") - else: - raise TypeError("types must be a set or a string.") - self.loss_types = loss_types + "classification": self.__compute_classification_loss + } def forward(self, logits, labels): # Compute Size Loss @@ -65,15 +65,10 @@ def forward(self, logits, labels): self.predictor.calculate_threshold(cal_logits.detach(), cal_labels.detach(), self.alpha) tau = self.predictor.q_hat - test_scores = self.predictor.score_function.predict(test_logits) + test_scores = self.predictor.score_function(test_logits) + # Computing the probability of each label contained in the prediction set. pred_sets = torch.sigmoid(tau - test_scores) - - if type(self.loss_types) == set: - loss = torch.tensor(0).to(logits.device) - for i in range(len(self.loss_types)): - loss += self.weight[i] * self.loss_functions_dict[self.loss_types[i]](pred_sets, test_labels) - else: - loss = self.weight * self.loss_functions_dict[self.loss_types](pred_sets, test_labels) + loss = self.weight * self.loss_functions_dict[self.loss_type](pred_sets, test_labels) if self.base_loss_fn is not None: loss += self.base_loss_fn(logits, labels).float() @@ -109,13 +104,13 @@ def __compute_coverage_loss(self, pred_sets, labels): def __compute_classification_loss(self, pred_sets, labels): # Convert labels to one-hot encoding one_hot_labels = F.one_hot(labels, num_classes=pred_sets.shape[1]).float() - loss_matrix = torch.eye(pred_sets.shape[1]).to(pred_sets.device) + loss_matrix = torch.eye(pred_sets.shape[1], device=pred_sets.device) # Calculate l1 and l2 losses l1 = (1 - pred_sets) * one_hot_labels * loss_matrix[labels] l2 = pred_sets * (1 - one_hot_labels) * loss_matrix[labels] # Calculate the total loss - loss = torch.sum(torch.maximum(l1 + l2, torch.zeros_like(l1).to(pred_sets.device)), dim=1) + loss = torch.sum(torch.maximum(l1 + l2, torch.zeros_like(l1, device=pred_sets.device)), dim=1) # Return the mean loss return torch.mean(loss) diff --git a/torchcp/classification/predictors/base.py b/torchcp/classification/predictors/base.py index ae954c0..528efd2 100644 --- a/torchcp/classification/predictors/base.py +++ b/torchcp/classification/predictors/base.py @@ -24,6 +24,8 @@ class BasePredictor(object): def __init__(self, score_function, model=None, temperature=1): """ + Abstract base class for all conformal predictors. + :param score_function: non-conformity score function. :param model: a deep learning model. """