
Commit

update confTr and docs
Jianguo99 committed Dec 25, 2023
1 parent d9d3010 commit 99bebb9
Showing 7 changed files with 221 additions and 144 deletions.
2 changes: 1 addition & 1 deletion docs/source/installation.rst
@@ -9,7 +9,7 @@ We developed TorchCP under Python 3.9 and PyTorch 2.0.1. To install TorchCP, sim

.. code-block:: bash
pip install --index-url https://test.pypi.org/simple/ --no-deps torchcp
pip install torchcp
or clone the repo and run

11 changes: 11 additions & 0 deletions docs/source/torchcp.classification.rst
@@ -26,6 +26,14 @@ Predictors
ClusterPredictor
WeightedPredictor

.. automodule:: torchcp.classification.loss
Loss functions
--------------

.. autosummary::
:nosignatures:

ConfTr

Detailed description
--------------------
@@ -57,4 +65,7 @@ Detailed description
:members:

.. autoclass:: WeightedPredictor
:members:

.. autoclass:: ConfTr
:members:
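
For orientation, a minimal sketch of how the newly documented ConfTr loss drops into a single training step. The constructor arguments are copied from examples/conformal_training.py in this commit; the toy linear model and random batch are placeholders, not part of the commit:

import torch
import torch.nn as nn
import torch.optim as optim

from torchcp.classification.loss import ConfTr
from torchcp.classification.predictors import SplitPredictor
from torchcp.classification.scores import THR

# Toy stand-ins so the sketch runs end to end (not from this commit).
model = nn.Linear(20, 10)
data = torch.randn(64, 20)
target = torch.randint(0, 10, (64,))

# Arguments copied from examples/conformal_training.py below.
criterion = ConfTr(weights=0.01,
                   predictor=SplitPredictor(score_function=THR(score_type="log_softmax")),
                   alpha=0.05,
                   fraction=0.5,
                   loss_types="valid",
                   base_loss_fn=nn.CrossEntropyLoss())

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
optimizer.zero_grad()
loss = criterion(model(data), target)  # called exactly like nn.CrossEntropyLoss
loss.backward()
optimizer.step()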
153 changes: 67 additions & 86 deletions examples/conformal_training.py
@@ -17,6 +17,7 @@


import argparse
import itertools

import torch
import torch.nn as nn
@@ -29,27 +30,9 @@
from torchcp.classification.scores import THR, APS, SAPS, RAPS
from torchcp.utils import fix_randomness

if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Covariate shift')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--predictor', default="Standard", help="Standard")
parser.add_argument('--score', default="THR", help="THR")
parser.add_argument('--loss', default="CE", help="CE | ConfTr")
args = parser.parse_args()
res = {'Coverage_rate': 0, 'Average_size': 0}
num_trials = 1
for seed in range(num_trials):
fix_randomness(seed=seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
##################################
# Invalid prediction sets
##################################
train_dataset = build_dataset("mnist")
train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=512, shuffle=True, pin_memory=True)


class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(28 * 28, 500)
@@ -60,76 +43,74 @@ def forward(self, x):
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x

def train(model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()



model = Net().to(device)

if args.loss == "CE":
if __name__ == '__main__':
alpha = 0.01
num_trials = 5

result = {}
for loss in ["CE", "ConfTr"]:
print(f"############################## {loss} #########################")
result[loss] = {}
if loss == "CE":
criterion = nn.CrossEntropyLoss()
elif args.loss == "ConfTr":
elif loss == "ConfTr":
predictor = SplitPredictor(score_function=THR(score_type="log_softmax"))
criterion = ConfTr(weights=0.01,
predictor=predictor,
alpha=0.05,
fraction=0.5,
loss_types="valid",
base_loss_fn=nn.CrossEntropyLoss())
else:
raise NotImplementedError

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)


def train(model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
if batch_idx % 10 == 0:
print(
f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')


checkpoint_path = f'.cache/conformal_training_model_checkpoint_{args.loss}_seed={seed}.pth'
# if os.path.exists(checkpoint_path):
# checkpoint = torch.load(checkpoint_path)
# model.load_state_dict(checkpoint['model_state_dict'])
# else:
for epoch in range(1, 10):
train(model, device, train_data_loader, optimizer, epoch)

torch.save({'model_state_dict': model.state_dict(), }, checkpoint_path)

test_dataset = build_dataset("mnist", mode='test')
cal_dataset, test_dataset = torch.utils.data.random_split(test_dataset, [5000, 5000])
cal_data_loader = torch.utils.data.DataLoader(cal_dataset, batch_size=1600, shuffle=False, pin_memory=True)
test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1600, shuffle=False, pin_memory=True)

if args.score == "THR":
score_function = THR()
elif args.score == "APS":
score_function = APS()
elif args.score == "RAPS":
score_function = RAPS(args.penalty, args.kreg)
elif args.score == "SAPS":
score_function = SAPS(weight=args.weight)

alpha = 0.01
if args.predictor == "Standard":
predictor = SplitPredictor(score_function, model)
elif args.predictor == "ClassWise":
predictor = ClassWisePredictor(score_function, model)
elif args.predictor == "Cluster":
predictor = ClusterPredictor(score_function, model, args.seed)
predictor.calibrate(cal_data_loader, alpha)

# test examples
tmp_res = predictor.evaluate(test_data_loader)
res['Coverage_rate'] += tmp_res['Coverage_rate'] / num_trials
res['Average_size'] += tmp_res['Average_size'] / num_trials

print(res)
for seed in range(num_trials):
fix_randomness(seed=seed)
##################################
# Training a PyTorch model
##################################
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_dataset = build_dataset("mnist")
train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=512, shuffle=True, pin_memory=True)
test_dataset = build_dataset("mnist", mode='test')
cal_dataset, test_dataset = torch.utils.data.random_split(test_dataset, [5000, 5000])
cal_data_loader = torch.utils.data.DataLoader(cal_dataset, batch_size=1600, shuffle=False, pin_memory=True)
test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1600, shuffle=False, pin_memory=True)

model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
for epoch in range(1, 10):
train(model, device, train_data_loader, optimizer, epoch)

for score in ["THR", "APS", "RAPS", "SAPS"]:
if score == "THR":
score_function = THR()
elif score == "APS":
score_function = APS()
elif score == "RAPS":
score_function = RAPS(1, 0)
elif score == "SAPS":
score_function = SAPS(weight=0.2)
if score not in result[loss]:
result[loss][score] = {}
result[loss][score]['Coverage_rate'] = 0
result[loss][score]['Average_size'] = 0
predictor = SplitPredictor(score_function, model)
predictor.calibrate(cal_data_loader, alpha)
tmp_res = predictor.evaluate(test_data_loader)
result[loss][score]['Coverage_rate'] += tmp_res['Coverage_rate'] / num_trials
result[loss][score]['Average_size'] += tmp_res['Average_size'] / num_trials

for score in ["THR", "APS", "RAPS", "SAPS"]:
print(f"Score: {score}. Result is {result[loss][score]}")
28 changes: 8 additions & 20 deletions examples/imagenet_example.py
@@ -34,8 +34,10 @@

fix_randomness(seed=args.seed)

#######################################
# Loading ImageNet dataset and a pytorch model
#######################################
model_name = 'ResNet101'
# load model
model = torchvision.models.resnet101(weights="IMAGENET1K_V1", progress=True)
model_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(model_device)
@@ -55,6 +57,10 @@
cal_data_loader = torch.utils.data.DataLoader(cal_dataset, batch_size=1024, shuffle=False, pin_memory=True)
test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1024, shuffle=False, pin_memory=True)


#######################################
# A standard process of conformal prediction
#######################################
alpha = args.alpha
print(
f"Experiment--Data : ImageNet, Model : {model_name}, Score : {args.score}, Predictor : {args.predictor}, Alpha : {alpha}")
@@ -80,22 +86,4 @@
raise NotImplementedError
print(f"The size of calibration set is {len(cal_dataset)}.")
predictor.calibrate(cal_data_loader, alpha)
# predictor.evaluate(test_data_loader)

# test examples
print("Testing examples...")
prediction_sets = []
labels_list = []
with torch.no_grad():
for examples in tqdm(test_data_loader):
tmp_x, tmp_label = examples[0], examples[1]
prediction_sets_batch = predictor.predict(tmp_x)
prediction_sets.extend(prediction_sets_batch)
labels_list.append(tmp_label)
test_labels = torch.cat(labels_list)

metrics = Metrics()
print("Etestuating prediction sets...")
print(f"Coverage_rate: {metrics('coverage_rate')(prediction_sets, test_labels)}.")
print(f"Average_size: {metrics('average_size')(prediction_sets, test_labels)}.")
print(f"CovGap: {metrics('CovGap')(prediction_sets, test_labels, alpha, num_classes)}.")
predictor.evaluate(test_data_loader)
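
The single evaluate call above replaces the manual loop this commit removes. For reference, a standalone sketch of that manual path, assuming the script's predictor, test_data_loader, alpha, and num_classes:

import torch
from tqdm import tqdm

from torchcp.classification.utils.metrics import Metrics

# Manual equivalent of predictor.evaluate(test_data_loader):
# build prediction sets batch by batch, then score them.
prediction_sets = []
labels_list = []
with torch.no_grad():
    for batch_x, batch_labels in tqdm(test_data_loader):
        prediction_sets.extend(predictor.predict(batch_x))
        labels_list.append(batch_labels)
test_labels = torch.cat(labels_list)

metrics = Metrics()
print(f"Coverage_rate: {metrics('coverage_rate')(prediction_sets, test_labels)}.")
print(f"Average_size: {metrics('average_size')(prediction_sets, test_labels)}.")
print(f"CovGap: {metrics('CovGap')(prediction_sets, test_labels, alpha, num_classes)}.")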
100 changes: 100 additions & 0 deletions tests/test_calssification_logits.py
@@ -0,0 +1,100 @@
# Copyright (c) 2023-present, SUSTech-ML.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#


import argparse
import os
import pickle

import torch
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as trn
from tqdm import tqdm

from torchcp.classification.predictors import SplitPredictor, ClusterPredictor, ClassWisePredictor
from torchcp.classification.scores import THR, APS, SAPS, RAPS, Margin
from torchcp.classification.utils.metrics import Metrics
from torchcp.utils import fix_randomness




def test_imagenet():
#######################################
# Loading ImageNet dataset and a pytorch model
#######################################
fix_randomness(seed=0)
model_name = 'ResNet101'
fname = ".cache/" + model_name + ".pkl"
if os.path.exists(fname):
with open(fname, 'rb') as handle:
dataset = pickle.load(handle)

else:
# load dataset
transform = trn.Compose([trn.Resize(256),
trn.CenterCrop(224),
trn.ToTensor(),
trn.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
usr_dir = os.path.expanduser('~')
data_dir = os.path.join(usr_dir, "data")
dataset = dset.ImageFolder(data_dir + "/imagenet/val",
transform)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=320, shuffle=False, pin_memory=True)

# load model
model = torchvision.models.resnet101(weights="IMAGENET1K_V1", progress=True)

logits_list = []
labels_list = []
with torch.no_grad():
for examples in tqdm(data_loader):
tmp_x, tmp_label = examples[0], examples[1]
tmp_logits = model(tmp_x)
logits_list.append(tmp_logits)
labels_list.append(tmp_label)
logits = torch.cat(logits_list)
labels = torch.cat(labels_list)
dataset = torch.utils.data.TensorDataset(logits, labels.long())
with open(fname, 'wb') as handle:
pickle.dump(dataset, handle, protocol=pickle.HIGHEST_PROTOCOL)

cal_data, val_data = torch.utils.data.random_split(dataset, [25000, 25000])
cal_logits = torch.stack([sample[0] for sample in cal_data])
cal_labels = torch.stack([sample[1] for sample in cal_data])

test_logits = torch.stack([sample[0] for sample in val_data])
test_labels = torch.stack([sample[1] for sample in val_data])

num_classes = 1000

#######################################
# A standard process of conformal prediction
#######################################
alpha = 0.1
predictors = [SplitPredictor, ClassWisePredictor, ClusterPredictor]
score_functions = [THR(), APS(), RAPS(1, 0), SAPS(0.2), Margin()]
for score in score_functions:
for class_predictor in predictors:
predictor = class_predictor(score)
predictor.calculate_threshold(cal_logits, cal_labels, alpha)
print(f"Experiment--Data : ImageNet, Model : {model_name}, Score : {score.__class__.__name__}, Predictor : {predictor.__class__.__name__}, Alpha : {alpha}")
# print("Testing examples...")
# prediction_sets = []
# for index, ele in enumerate(test_logits):
# prediction_set = predictor.predict_with_logits(ele)
# prediction_sets.append(prediction_set)
prediction_sets = predictor.predict_with_logits(test_logits)

metrics = Metrics()
print("Evaluating prediction sets...")
print(f"Coverage_rate: {metrics('coverage_rate')(prediction_sets, test_labels)}.")
print(f"Average_size: {metrics('average_size')(prediction_sets, test_labels)}.")
print(f"CovGap: {metrics('CovGap')(prediction_sets, test_labels, alpha, num_classes)}.")
