Commit 5dcb35f (1 parent: 0edafd4). Showing 8 changed files with 978 additions and 0 deletions.
New file, +1 line (a single ignore pattern, presumably the repository's .gitignore):
*.pyc
New file, +72 lines (the training entry-point script; its filename is not shown in this capture):
import torch, torchvision
import torch.nn as nn
import os
# import utils.notebook_utils as nbutils
import models
import utils.notebook_datautils as datutils
import argparse
import utils.notebook_trainutils as trainutil


def return_resnet34(nclass):
    # Torchvision ResNet-34 with the classifier head resized to nclass outputs.
    # Device placement is left to the caller, which moves the model with
    # .to(device) after instantiation.
    resnet34 = torchvision.models.resnet34()
    resnet34.fc = nn.Linear(512, nclass)
    return resnet34


# SETUP GPU
if not os.path.exists('saved_models'):
    os.mkdir('saved_models')

torch.backends.cudnn.benchmark = True
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
download = True

# Each entry is a zero-argument factory: the custom model classes directly,
# and lambdas wrapping return_resnet34 for the torchvision backbones, so that
# model_def() below builds a fresh model for every ensemble member.
model_dict = {'CIFAR10': models.FastResNet,
              'CIFAR100': models.FastResNet,
              'DIABETIC_RETINOPATHY': models.DRModel,
              'IMAGENETTE': lambda: return_resnet34(nclass=10),
              'IMAGEWOOF': lambda: return_resnet34(nclass=10),
              'MNIST': models.LeNet
              }

parser = argparse.ArgumentParser(description='Train multiple models sequentially')

parser.add_argument("--dataset", type=str, choices=['CIFAR10', 'CIFAR100', 'DIABETIC_RETINOPATHY', 'IMAGEWOOF',
                                                    'IMAGENETTE'], default='CIFAR10', help="Name of the dataset")
parser.add_argument("--datadir", type=str, help="Path to dataset")
parser.add_argument("--nmodel", type=int, help="How many models to train (deep ensemble)", default=1)
parser.add_argument("--mixup", type=float, help="Alpha for mixup; omit to train without mixup", default=None)
parser.add_argument("--ntrain", type=int, help="How many training examples to include, -1 for the full dataset", default=-1)
parser.add_argument("--nval", type=int, help="How many validation examples to include", default=0)
parser.add_argument("--epoch", type=int, help="Number of epochs to train")
parser.add_argument("--max_lr", type=float, help="Maximum learning rate during LR scheduling (OneCycleLR)")
parser.add_argument("--bsize", type=int, help="Batch size")
parser.add_argument("--wd", type=float, help="Weight decay")

args = parser.parse_args()

# Dataset (num_class is unused here; the model factories fix their own output widths)
print(f'Dataset chosen: {args.dataset}\n')
loaders, num_class = datutils.return_loaders(dataset=args.dataset, base=args.datadir, batch_size=args.bsize,
                                             start=args.ntrain, end=args.nval + args.ntrain)
# Architecture
model_def = model_dict[args.dataset]

# Train args.nmodel independent members of a deep ensemble.
for i in range(args.nmodel):
    model = model_def().to(device)
    optimizer, scheduler = trainutil.create_optim_schedule(model, loaders['train'], args.epoch, max_lr=args.max_lr,
                                                           weight_decay=args.wd)
    criterion = torch.nn.CrossEntropyLoss(reduction='mean')
    last_accuracy = trainutil.perform_train(model=model, criterion=criterion, loaders=loaders, optimizer=optimizer,
                                            scheduler=scheduler, mixup=args.mixup, n_epoch=args.epoch, device=device)

    # Save each member under a name that encodes the training configuration.
    dataset = args.dataset.upper()
    savefile = dataset
    savefile += '_ntrain-' + str(len(loaders['train'].dataset))
    savefile += '_MixUpAlpha-' + str(args.mixup)
    savefile += '_id-' + str(i + 1)
    checkpoint = {'model_state': model.state_dict(),
                  'optim_state': optimizer.state_dict(),
                  'acc': last_accuracy}
    torch.save(checkpoint, 'saved_models/' + savefile + '.model')
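The --mixup flag forwards an alpha value to trainutil.perform_train, whose implementation is not part of this commit. As a reference for what such a step typically looks like, here is a minimal sketch of the standard mixup formulation, not the repository's actual utility code:

import numpy as np
import torch

def mixup_batch(x, y, alpha):
    # Blend each example with a randomly chosen partner from the same batch;
    # the mixing weight is drawn from Beta(alpha, alpha).
    lam = float(np.random.beta(alpha, alpha))
    perm = torch.randperm(x.size(0), device=x.device)
    return lam * x + (1 - lam) * x[perm], y, y[perm], lam

# In the training loop, the loss is blended with the same weight:
#   xm, ya, yb, lam = mixup_batch(inputs, targets, args.mixup)
#   out = model(xm)
#   loss = lam * criterion(out, ya) + (1 - lam) * criterion(out, yb)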
New file, +3 lines (the models package __init__.py, inferred from the relative imports below):
from .diab_retin_kaggle import DRModel
from .fast_resnet import FastResNet
from .lenet import LeNet
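With these exports and the checkpoint dictionary written by the training script ('model_state', 'optim_state', 'acc'), restoring a saved ensemble member could look like the following sketch; the checkpoint filename below is illustrative, built from the script's naming scheme rather than taken from an actual run:

import torch
import models

# Hypothetical filename following the script's savefile convention.
checkpoint = torch.load('saved_models/CIFAR10_ntrain-50000_MixUpAlpha-None_id-1.model',
                        map_location='cpu')
model = models.FastResNet()                       # must match the saved architecture
model.load_state_dict(checkpoint['model_state'])
model.eval()
print('stored final accuracy:', checkpoint['acc'])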
New file, +136 lines (models/diab_retin_kaggle.py, per the package imports above):
import torch.nn as nn
from torch.nn import functional as F


class DRModel(nn.Module):
    """
    Model based on the fifth-place solution for the Kaggle diabetic
    retinopathy competition. Adapted from
    https://github.com/JeffreyDF/kaggle_diabetic_retinopathy/blob/
    """

    def __init__(self, num_channel=3, dropout_rate=0.0, num_classes=2):
        """
        Model for the diabetic retinopathy dataset
        :param num_channel: (int) number of channels of the input images
        :param dropout_rate: (float in [0.0, 1.0)) dropout rate
        :param num_classes: (int) number of classes to predict
        """
        super(DRModel, self).__init__()
        # Stem: 7x7 stride-2 convolution; padding=3 keeps the halving exact.
        self.conv1 = nn.Conv2d(num_channel, 32, kernel_size=7, stride=2,
                               padding=3, bias=True)
        self.leakyrelu = nn.LeakyReLU(0.5, inplace=True)

        self.maxpool2d = nn.MaxPool2d(kernel_size=3, stride=3)

        self.maxpool1d = nn.MaxPool1d(kernel_size=2, stride=2)  # defined but not used in forward

        # 3x3 stride-1 convolutions with padding=1 (size-preserving),
        # widening the channels from 32 up to 256.
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv4 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv5 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv6 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv7 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv8 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv9 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv10 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv11 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv12 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv13 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True)

        self.dropout = nn.Dropout(p=dropout_rate)

        # One batchnorm per convolution, matching its output width.
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(32)
        self.bn3 = nn.BatchNorm2d(32)
        self.bn4 = nn.BatchNorm2d(64)
        self.bn5 = nn.BatchNorm2d(64)
        self.bn6 = nn.BatchNorm2d(128)
        self.bn7 = nn.BatchNorm2d(128)
        self.bn8 = nn.BatchNorm2d(128)
        self.bn9 = nn.BatchNorm2d(128)
        self.bn10 = nn.BatchNorm2d(256)
        self.bn11 = nn.BatchNorm2d(256)
        self.bn12 = nn.BatchNorm2d(256)
        self.bn13 = nn.BatchNorm2d(256)

        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        batch_size = x.size(0)

        # Stage 1: stem.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.leakyrelu(x)

        x = self.maxpool2d(x)

        # Stage 2: two 32-channel convs.
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.leakyrelu(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.leakyrelu(x)

        x = self.maxpool2d(x)

        # Stage 3: two 64-channel convs.
        x = self.conv4(x)
        x = self.bn4(x)
        x = self.leakyrelu(x)
        x = self.conv5(x)
        x = self.bn5(x)
        x = self.leakyrelu(x)

        x = self.maxpool2d(x)

        # Stage 4: four 128-channel convs.
        x = self.conv6(x)
        x = self.bn6(x)
        x = self.leakyrelu(x)
        x = self.conv7(x)
        x = self.bn7(x)
        x = self.leakyrelu(x)
        x = self.conv8(x)
        x = self.bn8(x)
        x = self.leakyrelu(x)
        x = self.conv9(x)
        x = self.bn9(x)
        x = self.leakyrelu(x)

        x = self.maxpool2d(x)

        # Stage 5: four 256-channel convs.
        x = self.conv10(x)
        x = self.bn10(x)
        x = self.leakyrelu(x)
        x = self.conv11(x)
        x = self.bn11(x)
        x = self.leakyrelu(x)
        x = self.conv12(x)
        x = self.bn12(x)
        x = self.leakyrelu(x)
        x = self.conv13(x)
        x = self.bn13(x)
        x = self.leakyrelu(x)

        x = self.maxpool2d(x)
        # Flatten; fc1 expects a 256-dim vector, i.e. a 1x1 final feature map.
        x = x.view(batch_size, -1)

        x = self.dropout(x)
        x = self.fc1(x)

        x = F.relu(x)
        x = self.fc2(x)

        return x
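Because fc1 expects a flat 256-dim vector, the pooling stack must shrink the input to a 1x1 map; with the stride-2 stem and five stride-3 max-pools, a 512x512 input works out to exactly that (my assumption about the intended resolution; the commit does not state it). A quick shape check:

import torch

model = DRModel(num_channel=3, dropout_rate=0.0, num_classes=2)
x = torch.randn(2, 3, 512, 512)   # two 512x512 RGB images
print(model(x).shape)             # expected: torch.Size([2, 2])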
New file, +90 lines (models/fast_resnet.py, per the package imports above):
import torch
from torch import nn


class Flatten(torch.nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


def conv_bn(channels_in, channels_out, kernel_size=3, stride=1, padding=1, groups=1, bn=True, activation=True, bias=True):
    """
    Build a conv -> (batchnorm) -> (ReLU) block.
    :param channels_in: (int) input channel dimension
    :param channels_out: (int) output channel dimension
    :param kernel_size: (int) size of the convolutional kernel
    :param stride: (int) stride of the convolution
    :param padding: (int) padding of the convolution
    :param groups: (int) number of convolution groups
    :param bn: (bool) whether to apply batchnorm
    :param activation: (bool) whether to add a ReLU activation
    :param bias: (bool) whether to add a bias to the convolution
    :return: an nn.Sequential module
    """
    op = [
        torch.nn.Conv2d(channels_in, channels_out,
                        kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=bias),
    ]
    if bn:
        op.append(torch.nn.BatchNorm2d(channels_out))
    if activation:
        op.append(torch.nn.ReLU(inplace=True))
    return torch.nn.Sequential(*op)


class Residual(torch.nn.Module):
    """Wrap a module with an identity skip connection: y = x + module(x)."""

    def __init__(self, module):
        super(Residual, self).__init__()
        self.module = module

    def forward(self, x):
        return x + self.module(x)


class FastResNet(nn.Module):

    def __init__(self, num_class=10, sphere_projected=False, last_relu=True, bias=False, dropout=None):
        """
        Fast version of ResNet adapted from the blog: https://myrtle.ai/learn/how-to-train-your-resnet/
        :param num_class: (int) number of classes
        :param sphere_projected: (bool) whether to project features onto the unit sphere before the classifier
        :param last_relu: (bool) whether to apply a ReLU after the last conv block
        :param bias: (bool) whether to add a bias to the fully connected layer
        :param dropout: (None or float) dropout rate before the FC layer; None disables dropout
        """
        super(FastResNet, self).__init__()
        self.encoder = torch.nn.Sequential(
            conv_bn(3, 64, kernel_size=3, stride=1, padding=1),
            conv_bn(64, 128, kernel_size=5, stride=2, padding=2),
            # torch.nn.MaxPool2d(2),

            Residual(torch.nn.Sequential(
                conv_bn(128, 128),
                conv_bn(128, 128),
            )),

            conv_bn(128, 256, kernel_size=3, stride=1, padding=1),
            torch.nn.MaxPool2d(2),

            Residual(torch.nn.Sequential(
                conv_bn(256, 256),
                conv_bn(256, 256),
            )),

            conv_bn(256, 128, kernel_size=3, stride=1, padding=0, activation=last_relu),

            torch.nn.AdaptiveMaxPool2d((1, 1)),
            Flatten())
        self.dropout = None if dropout is None else nn.Dropout(p=dropout)
        self.fc = torch.nn.Linear(128, num_class, bias=bias)
        self.sphere = sphere_projected

    def forward(self, x):
        x = self.encoder(x)
        if self.dropout is not None:
            x = self.dropout(x)
        if self.sphere:
            # Normalize the 128-d feature vector to unit L2 norm.
            norms = torch.norm(x, dim=1, keepdim=True)
            x = x / norms
        x = self.fc(x)
        return x
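A quick sanity check on a CIFAR-shaped batch, with the unit-sphere projection switched on; batch size and dropout rate here are illustrative:

import torch

net = FastResNet(num_class=10, sphere_projected=True, dropout=0.1)
x = torch.randn(4, 3, 32, 32)     # CIFAR-sized batch
print(net(x).shape)               # expected: torch.Size([4, 10])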
New file, +38 lines (models/lenet.py, per the package imports above):
import torch.nn as nn
import torch.nn.functional as F
import torch


class LeNet(nn.Module):
    # Despite the name, the active layers form a small MNIST-style CNN
    # (two 3x3 convs, one 2x2 max-pool, two FC layers) rather than the
    # original LeNet-5.
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)   # 64 channels x 12 x 12 for 28x28 input
        self.fc2 = nn.Linear(128, 10)

        # Commented-out alternative: a classic LeNet-5-style variant for
        # 3-channel input.
        # self.conv1 = nn.Conv2d(3, 6, 5)
        # self.conv2 = nn.Conv2d(6, 16, 5)
        # self.fc1 = nn.Linear(16*5*5, 120)
        # self.fc2 = nn.Linear(120, 84)
        # self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        # Forward pass for the LeNet-5-style variant above:
        # out = F.relu(self.conv1(x))
        # out = F.max_pool2d(out, 2)
        # out = F.relu(self.conv2(out))
        # out = F.max_pool2d(out, 2)
        # out = out.view(out.size(0), -1)
        # out = F.relu(self.fc1(out))
        # out = F.relu(self.fc2(out))
        # out = self.fc3(out)
        return x
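The hard-coded fc1 width of 9216 corresponds to 28x28 single-channel input: two 3x3 convolutions shrink 28 to 26 and then 24, the 2x2 max-pool halves that to 12, and 64 channels x 12 x 12 = 9216. A quick check:

import torch

net = LeNet()
x = torch.randn(8, 1, 28, 28)     # MNIST-shaped batch
print(net(x).shape)               # expected: torch.Size([8, 10])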