Merge pull request #10 from ml-stat-Sustech/development
Development
hongxin001 authored Dec 26, 2023
2 parents 5841e16 + 3e3db8f commit 018ae1b
Showing 34 changed files with 927 additions and 572 deletions.
37 changes: 37 additions & 0 deletions docs/source/conf.py
@@ -78,6 +78,41 @@
# The master toctree document.
master_doc = 'index'

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'

todo_include_todos = False

# Resolve function for the linkcode extension.
def linkcode_resolve(domain, info):
def find_source():
# try to find the file and line number, based on code from numpy:
# https://github.com/numpy/numpy/blob/master/doc/source/conf.py#L286
obj = sys.modules[info['module']]
for part in info['fullname'].split('.'):
obj = getattr(obj, part)
import inspect
import os
fn = inspect.getsourcefile(obj)
fn = os.path.relpath(fn, start=os.path.dirname(torchcp.__file__))
source, lineno = inspect.getsourcelines(obj)
return fn, lineno, lineno + len(source) - 1

if domain != 'py' or not info['module']:
return None
try:
filename = 'torchcp/%s#L%d-L%d' % find_source()
except Exception:
filename = info['module'].replace('.', '/') + '.py'
tag = 'master'
url = "https://github.com/ml-stat-Sustech/TorchCP/blob/%s/%s"
return url % (tag, filename)

# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output

@@ -95,3 +130,5 @@

# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True


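For reference (not part of this diff): the lookup that the new linkcode_resolve performs can be reproduced standalone. A minimal sketch, assuming torchcp is installed; fix_randomness (used in this PR's examples) serves only as a sample target:

import inspect
import os

import torchcp
from torchcp.utils import fix_randomness  # sample target only

obj = fix_randomness
fn = os.path.relpath(inspect.getsourcefile(obj),
                     start=os.path.dirname(torchcp.__file__))
source, lineno = inspect.getsourcelines(obj)
# Same "path#Lstart-Lend" fragment that linkcode_resolve builds:
print('torchcp/%s#L%d-L%d' % (fn, lineno, lineno + len(source) - 1))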
2 changes: 1 addition & 1 deletion docs/source/installation.rst
@@ -9,7 +9,7 @@ We developed TorchCP under Python 3.9 and PyTorch 2.0.1. To install TorchCP, sim

.. code-block:: bash
pip install --index-url https://test.pypi.org/simple/ --no-deps torchcp
pip install torchcp
or clone the repo and run

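For reference (not part of this diff): a quick post-install check, assuming the distribution name matches the pip package (importlib.metadata is standard library on the Python 3.9 these docs target):

import importlib.metadata

import torchcp

print('torchcp', importlib.metadata.version('torchcp'))
print('installed at', torchcp.__file__)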
14 changes: 13 additions & 1 deletion docs/source/torchcp.classification.rst
@@ -26,6 +26,14 @@ Predictors
ClusterPredictor
WeightedPredictor

.. automodule:: torchcp.classification.loss
Loss functions
-------

.. autosummary::
:nosignatures:

ConfTr

Detailed description
--------------------
@@ -57,4 +65,8 @@ Detailed description
:members:

.. autoclass:: WeightedPredictor
:members:
:members:

.. autoclass:: ConfTr
:members:

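For reference (not part of this diff): ConfTr is the conformal-training loss exercised by examples/conformal_training.py later in this commit. A minimal construction mirroring the arguments used there (the SplitPredictor import path is assumed from TorchCP's package layout):

import torch.nn as nn

from torchcp.classification.loss import ConfTr
from torchcp.classification.predictors import SplitPredictor  # path assumed
from torchcp.classification.scores import THR

predictor = SplitPredictor(score_function=THR(score_type="log_softmax"))
criterion = ConfTr(weight=0.01,                       # weight of the conformal term
                   predictor=predictor,               # builds prediction sets during training
                   alpha=0.05,                        # target miscoverage while training
                   fraction=0.5,                      # batch fraction held out for calibration
                   loss_type="valid",
                   base_loss_fn=nn.CrossEntropyLoss())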
13 changes: 13 additions & 0 deletions docs/source/torchcp.regression.rst
@@ -12,6 +12,15 @@ Predictors
cqr
ACI

.. automodule:: torchcp.regression.loss
Loss functions
-------

.. autosummary::
:nosignatures:

QuantileLoss


Detailed description
--------------------
@@ -24,3 +33,7 @@ Detailed description

.. autoclass:: ACI
:members:

.. autoclass:: QuantileLoss
:members:

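For reference (not part of this diff): QuantileLoss trains the quantile regressors behind cqr; the underlying pinball loss is sketched below (a standalone illustration, not TorchCP's exact implementation):

import torch

def pinball_loss(pred, target, tau):
    # L_tau(y, y_hat) = max(tau * (y - y_hat), (tau - 1) * (y - y_hat))
    diff = target - pred
    return torch.maximum(tau * diff, (tau - 1) * diff).mean()

y = torch.tensor([1.0, 2.0, 3.0])
y_hat = torch.tensor([1.5, 1.5, 1.5])
print(pinball_loss(y_hat, y, 0.05), pinball_loss(y_hat, y, 0.95))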
144 changes: 52 additions & 92 deletions examples/conformal_training.py
@@ -17,6 +17,7 @@


import argparse
import itertools

import torch
import torch.nn as nn
@@ -29,27 +30,9 @@
from torchcp.classification.scores import THR, APS, SAPS, RAPS
from torchcp.utils import fix_randomness

if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Covariate shift')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--predictor', default="Standard", help="Standard")
parser.add_argument('--score', default="THR", help="THR")
parser.add_argument('--loss', default="CE", help="CE | ConfTr")
args = parser.parse_args()
res = {'Coverage_rate': 0, 'Average_size': 0}
num_trials = 1
for seed in range(num_trials):
fix_randomness(seed=seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
##################################
# Invalid prediction sets
##################################
train_dataset = build_dataset("mnist")
train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=512, shuffle=True, pin_memory=True)


class Net(nn.Module):
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(28 * 28, 500)
@@ -60,77 +43,54 @@ def forward(self, x):
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x

def train(model, device, train_loader,criterion, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()



model = Net().to(device)

if args.loss == "CE":
criterion = nn.CrossEntropyLoss()
elif args.loss == "ConfTr":
predictor = SplitPredictor(score_function=THR(score_type="log_softmax"))
criterion = ConfTr(weights=0.01,
predictor=predictor,
alpha=0.05,
device=device,
fraction=0.5,
loss_types="valid",
base_loss_fn=nn.CrossEntropyLoss())
else:
raise NotImplementedError

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)


def train(model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
if batch_idx % 10 == 0:
print(
f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')


checkpoint_path = f'.cache/conformal_training_model_checkpoint_{args.loss}_seed={seed}.pth'
# if os.path.exists(checkpoint_path):
# checkpoint = torch.load(checkpoint_path)
# model.load_state_dict(checkpoint['model_state_dict'])
# else:
for epoch in range(1, 10):
train(model, device, train_data_loader, optimizer, epoch)

torch.save({'model_state_dict': model.state_dict(), }, checkpoint_path)

test_dataset = build_dataset("mnist", mode='test')
cal_dataset, test_dataset = torch.utils.data.random_split(test_dataset, [5000, 5000])
cal_data_loader = torch.utils.data.DataLoader(cal_dataset, batch_size=1600, shuffle=False, pin_memory=True)
test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1600, shuffle=False, pin_memory=True)

if args.score == "THR":
score_function = THR()
elif args.score == "APS":
score_function = APS()
elif args.score == "RAPS":
score_function = RAPS(args.penalty, args.kreg)
elif args.score == "SAPS":
score_function = SAPS(weight=args.weight)

alpha = 0.01
if args.predictor == "Standard":
predictor = SplitPredictor(score_function, model)
elif args.predictor == "ClassWise":
predictor = ClassWisePredictor(score_function, model)
elif args.predictor == "Cluster":
predictor = ClusterPredictor(score_function, model, args.seed)
predictor.calibrate(cal_data_loader, alpha)

# test examples
tmp_res = predictor.evaluate(test_data_loader)
res['Coverage_rate'] += tmp_res['Coverage_rate'] / num_trials
res['Average_size'] += tmp_res['Average_size'] / num_trials

print(res)
if __name__ == '__main__':
alpha = 0.01
num_trials = 5
loss = "ConfTr"
result = {}
print(f"############################## {loss} #########################")

predictor = SplitPredictor(score_function=THR(score_type="log_softmax"))
criterion = ConfTr(weight=0.01,
predictor=predictor,
alpha=0.05,
fraction=0.5,
loss_type="valid",
base_loss_fn=nn.CrossEntropyLoss())

fix_randomness(seed=0)
##################################
# Training a pytorch model
##################################
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_dataset = build_dataset("mnist")
train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=512, shuffle=True, pin_memory=True)
test_dataset = build_dataset("mnist", mode='test')
cal_dataset, test_dataset = torch.utils.data.random_split(test_dataset, [5000, 5000])
cal_data_loader = torch.utils.data.DataLoader(cal_dataset, batch_size=1600, shuffle=False, pin_memory=True)
test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1600, shuffle=False, pin_memory=True)

model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
for epoch in range(1, 10):
train(model, device, train_data_loader, criterion, optimizer, epoch)


score_function = THR()

predictor = SplitPredictor(score_function, model)
predictor.calibrate(cal_data_loader, alpha)
result = predictor.evaluate(test_data_loader)
print(f"Result--Coverage_rate: {result['Coverage_rate']}, Average_size: {result['Average_size']}")
2 changes: 1 addition & 1 deletion examples/covariate_shift.py
@@ -19,7 +19,7 @@
from torchcp.utils import fix_randomness

if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Covariate shift')
parser = argparse.ArgumentParser(description='Coveriate shift')
parser.add_argument('--seed', default=0, type=int)
args = parser.parse_args()

6 changes: 3 additions & 3 deletions examples/dataset.py
@@ -13,7 +13,7 @@ def build_dataset(dataset_name, transform=None, mode="train"):
data_dir = os.path.join(usr_dir, "data")

if dataset_name == 'imagenet':
if transform == None:
if transform is None:
transform = trn.Compose([
trn.Resize(256),
trn.CenterCrop(224),
@@ -25,7 +25,7 @@ def build_dataset(dataset_name, transform=None, mode="train"):
dataset = dset.ImageFolder(data_dir + "/imagenet/val",
transform)
elif dataset_name == 'imagenetv2':
if transform == None:
if transform is None:
transform = trn.Compose([
trn.Resize(256),
trn.CenterCrop(224),
@@ -38,7 +38,7 @@ def build_dataset(dataset_name, transform=None, mode="train"):
transform)

elif dataset_name == 'mnist':
if transform == None:
if transform is None:
transform = trn.Compose([
trn.ToTensor(),
trn.Normalize((0.1307,), (0.3081,))
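For reference (not part of this diff): the == None to is None changes follow PEP 8. == dispatches to a class's __eq__, which can be overridden, while is tests object identity; a minimal illustration:

class AlwaysEqual:
    def __eq__(self, other):
        return True

x = AlwaysEqual()
print(x == None)  # True  -- misleading
print(x is None)  # False -- what the check actually means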
59 changes: 11 additions & 48 deletions examples/imagenet_example.py
@@ -24,18 +24,14 @@
parser = argparse.ArgumentParser(description='')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--alpha', default=0.1, type=float)
parser.add_argument('--predictor', default="Standard", help="Standard | ClassWise | Cluster")
parser.add_argument('--score', default="THR", help="THR | APS | SAPS")
parser.add_argument('--penalty', default=1, type=float)
parser.add_argument('--kreg', default=0, type=int)
parser.add_argument('--weight', default=0.2, type=int)
parser.add_argument('--split', default="random", type=str, help="proportional | doubledip | random")
args = parser.parse_args()

fix_randomness(seed=args.seed)

#######################################
# Loading ImageNet dataset and a pytorch model
#######################################
model_name = 'ResNet101'
# load model
model = torchvision.models.resnet101(weights="IMAGENET1K_V1", progress=True)
model_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(model_device)
@@ -55,47 +51,14 @@
cal_data_loader = torch.utils.data.DataLoader(cal_dataset, batch_size=1024, shuffle=False, pin_memory=True)
test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1024, shuffle=False, pin_memory=True)


#######################################
# A standard process of conformal prediction
#######################################
alpha = args.alpha
print(
f"Experiment--Data : ImageNet, Model : {model_name}, Score : {args.score}, Predictor : {args.predictor}, Alpha : {alpha}")
num_classes = 1000
if args.score == "THR":
score_function = THR()
elif args.score == "APS":
score_function = APS()
elif args.score == "RAPS":
score_function = RAPS(args.penalty, args.kreg)
elif args.score == "SAPS":
score_function = SAPS(weight=args.weight)
else:
raise NotImplementedError

if args.predictor == "Standard":
predictor = SplitPredictor(score_function, model)
elif args.predictor == "ClassWise":
predictor = ClassWisePredictor(score_function, model)
elif args.predictor == "Cluster":
predictor = ClusterPredictor(score_function, model, args.seed)
else:
raise NotImplementedError
print(f"Experiment--Data : ImageNet, Model : {model_name}, Score : THR, Predictor : SplitPredictor, Alpha : {alpha}")
score_function = THR()
predictor = SplitPredictor(score_function, model)
print(f"The size of calibration set is {len(cal_dataset)}.")
predictor.calibrate(cal_data_loader, alpha)
# predictor.evaluate(test_data_loader)

# test examples
print("Testing examples...")
prediction_sets = []
labels_list = []
with torch.no_grad():
for examples in tqdm(test_data_loader):
tmp_x, tmp_label = examples[0], examples[1]
prediction_sets_batch = predictor.predict(tmp_x)
prediction_sets.extend(prediction_sets_batch)
labels_list.append(tmp_label)
test_labels = torch.cat(labels_list)

metrics = Metrics()
print("Etestuating prediction sets...")
print(f"Coverage_rate: {metrics('coverage_rate')(prediction_sets, test_labels)}.")
print(f"Average_size: {metrics('average_size')(prediction_sets, test_labels)}.")
print(f"CovGap: {metrics('CovGap')(prediction_sets, test_labels, alpha, num_classes)}.")
predictor.evaluate(test_data_loader)
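For reference (not part of this diff): the removed block above spells out what predictor.evaluate(test_data_loader) now encapsulates. The equivalent manual computation, reconstructed from that loop (the Metrics import path is assumed):

import torch
from tqdm import tqdm

from torchcp.classification.utils.metrics import Metrics  # path assumed

prediction_sets, labels_list = [], []
with torch.no_grad():
    for tmp_x, tmp_label in tqdm(test_data_loader):
        prediction_sets.extend(predictor.predict(tmp_x))
        labels_list.append(tmp_label)
test_labels = torch.cat(labels_list)

metrics = Metrics()
print(f"Coverage_rate: {metrics('coverage_rate')(prediction_sets, test_labels)}.")
print(f"Average_size: {metrics('average_size')(prediction_sets, test_labels)}.")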
