update dot: a distillation-oriented trainer & configs
Zzzzz1 committed Nov 5, 2023
1 parent d3f5e99 commit a08d46f
Showing 20 changed files with 805 additions and 10 deletions.
Binary file added .github/dot.png
37 changes: 37 additions & 0 deletions README.md
@@ -6,6 +6,37 @@ This repo is

(2) the official implementation of the CVPR-2022 paper: [Decoupled Knowledge Distillation](https://arxiv.org/abs/2203.08679).

(3) the official implementation of the ICCV-2023 paper: [DOT: A Distillation-Oriented Trainer](https://openaccess.thecvf.com/content/ICCV2023/papers/Zhao_DOT_A_Distillation-Oriented_Trainer_ICCV_2023_paper.pdf).


# DOT: A Distillation-Oriented Trainer

### Framework

<div style="text-align:center"><img src=".github/dot.png" width="80%" ></div>

### Main Benchmark Results

On CIFAR-100:

| Teacher <br> Student | ResNet32x4 <br> ResNet8x4| VGG13 <br> VGG8| ResNet32x4 <br> ShuffleNet-V2|
|:---------------:|:-----------------:|:-----------------:|:-----------------:|
| KD | 73.33 | 72.98 | 74.45 |
| **KD+DOT** | **75.12** | **73.77** | **75.55** |

On Tiny-ImageNet:

| Teacher <br> Student |ResNet18 <br> MobileNet-V2|ResNet18 <br> ShuffleNet-V2|
|:---------------:|:-----------------:|:-----------------:|
| KD | 58.35 | 62.26 |
| **KD+DOT** | **64.01** | **65.75** |

On ImageNet:

| Teacher <br> Student |ResNet34 <br> ResNet18|ResNet50 <br> MobileNet-V1|
|:---------------:|:-----------------:|:-----------------:|
| KD | 71.03 | 70.50 |
| **KD+DOT** | **71.72** | **73.09** |

# Decoupled Knowledge Distillation

@@ -170,6 +201,12 @@ If this repo is helpful for your research, please consider citing the paper:
journal={arXiv preprint arXiv:2203.08679},
year={2022}
}
@article{zhao2023dot,
title={DOT: A Distillation-Oriented Trainer},
author={Zhao, Borui and Cui, Quan and Song, Renjie and Liang, Jiajun},
journal={arXiv preprint arXiv:2307.08436},
year={2023}
}
```

# License
20 changes: 20 additions & 0 deletions configs/cifar100/dot/res32x4_res8x4.yaml
@@ -0,0 +1,20 @@
EXPERIMENT:
  NAME: ""
  TAG: "kd,dot,res32x4,res8x4"
  PROJECT: "dot_cifar"
DISTILLER:
  TYPE: "KD"
  TEACHER: "resnet32x4"
  STUDENT: "resnet8x4"
SOLVER:
  BATCH_SIZE: 64
  EPOCHS: 240
  LR: 0.05
  LR_DECAY_STAGES: [150, 180, 210]
  LR_DECAY_RATE: 0.1
  WEIGHT_DECAY: 0.0005
  MOMENTUM: 0.9
  TYPE: "SGD"
  TRAINER: "dot"
  DOT:
    DELTA: 0.075
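
The trainer code itself is not shown in this part of the diff, so the following is only a minimal sketch of what `MOMENTUM` and `DELTA` plausibly control. It assumes, following the DOT paper's description, that the cross-entropy and distillation gradients keep separate momentum buffers, with coefficient `MOMENTUM - DELTA` for the task loss and `MOMENTUM + DELTA` for the distillation loss (0.825 and 0.975 with the values above). The function name `dot_momentum_step` and the scalar parameter are hypothetical, and weight decay is left out.

```python
def dot_momentum_step(param, grad_ce, grad_kd, buf_ce, buf_kd,
                      lr=0.05, momentum=0.9, delta=0.075):
    """One SGD step with separate momentum buffers (illustrative sketch).

    Assumption: the CE gradient uses momentum - delta and the KD gradient
    uses momentum + delta; weight decay is omitted for brevity.
    """
    # The task-loss buffer forgets faster (smaller momentum coefficient).
    buf_ce = (momentum - delta) * buf_ce + grad_ce
    # The distillation-loss buffer accumulates longer (larger coefficient).
    buf_kd = (momentum + delta) * buf_kd + grad_kd
    # The parameter moves along the sum of both buffers.
    param = param - lr * (buf_ce + buf_kd)
    return param, buf_ce, buf_kd


if __name__ == "__main__":
    p, bc, bk = 1.0, 0.0, 0.0
    for _ in range(2):
        p, bc, bk = dot_momentum_step(p, 1.0, 1.0, bc, bk)
    print(p, bc, bk)
```

With constant unit gradients, the distillation buffer grows faster than the task buffer, which is the intended asymmetry.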
20 changes: 20 additions & 0 deletions configs/cifar100/dot/res32x4_shuv2.yaml
@@ -0,0 +1,20 @@
EXPERIMENT:
  NAME: ""
  TAG: "kd,dot,res32x4,shuv2"
  PROJECT: "dot_cifar"
DISTILLER:
  TYPE: "KD"
  TEACHER: "resnet32x4"
  STUDENT: "ShuffleV2"
SOLVER:
  BATCH_SIZE: 64
  EPOCHS: 240
  LR: 0.01
  LR_DECAY_STAGES: [150, 180, 210]
  LR_DECAY_RATE: 0.1
  WEIGHT_DECAY: 0.0005
  MOMENTUM: 0.9
  TYPE: "SGD"
  TRAINER: "dot"
  DOT:
    DELTA: 0.075
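
To make the schedule fields concrete, here is a small sketch of a step-decay rule driven by `LR`, `LR_DECAY_STAGES` and `LR_DECAY_RATE`. The helper `lr_at_epoch` is hypothetical, and the convention that the decay applies to epochs strictly greater than each stage is an assumption; the repository's scheduler is not part of this diff.

```python
def lr_at_epoch(epoch, base_lr=0.01, decay_stages=(150, 180, 210), decay_rate=0.1):
    """Step-decay learning rate (illustrative sketch).

    Assumption: the rate is multiplied by decay_rate once for every
    stage boundary that the current epoch has strictly passed.
    """
    steps = sum(1 for stage in decay_stages if epoch > stage)
    return base_lr * decay_rate ** steps


if __name__ == "__main__":
    for epoch in (1, 150, 151, 181, 211, 240):
        print(epoch, lr_at_epoch(epoch))
```

Under this convention the ShuffleNet-V2 student trains at 0.01 through epoch 150 and drops by a factor of ten after each of the three stages.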
20 changes: 20 additions & 0 deletions configs/cifar100/dot/vgg13_vgg8.yaml
@@ -0,0 +1,20 @@
EXPERIMENT:
  NAME: ""
  TAG: "kd,dot,vgg13,vgg8"
  PROJECT: "dot_cifar"
DISTILLER:
  TYPE: "KD"
  TEACHER: "vgg13"
  STUDENT: "vgg8"
SOLVER:
  BATCH_SIZE: 64
  EPOCHS: 240
  LR: 0.05
  LR_DECAY_STAGES: [150, 180, 210]
  LR_DECAY_RATE: 0.1
  WEIGHT_DECAY: 0.0005
  MOMENTUM: 0.9
  TYPE: "SGD"
  TRAINER: "dot"
  DOT:
    DELTA: 0.075
33 changes: 33 additions & 0 deletions configs/imagenet/r34_r18/dot.yaml
@@ -0,0 +1,33 @@
EXPERIMENT:
  NAME: ""
  TAG: "kd,dot,res34,res18"
  PROJECT: "dot_imagenet"
DATASET:
  TYPE: "imagenet"
  NUM_WORKERS: 32
  TEST:
    BATCH_SIZE: 128
DISTILLER:
  TYPE: "KD"
  TEACHER: "ResNet34"
  STUDENT: "ResNet18"
SOLVER:
  BATCH_SIZE: 512
  EPOCHS: 100
  LR: 0.2
  LR_DECAY_STAGES: [30, 60, 90]
  LR_DECAY_RATE: 0.1
  WEIGHT_DECAY: 0.0001
  MOMENTUM: 0.9
  TYPE: "SGD"
  TRAINER: "dot"
  DOT:
    DELTA: 0.09
KD:
  TEMPERATURE: 1
  LOSS:
    CE_WEIGHT: 0.5
    KD_WEIGHT: 0.5
LOG:
  TENSORBOARD_FREQ: 50
  SAVE_CHECKPOINT_FREQ: 10
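
The `TEMPERATURE`, `CE_WEIGHT` and `KD_WEIGHT` fields suggest the usual weighted sum of a cross-entropy term and a temperature-scaled KL term. The sketch below assumes the standard Hinton-style formulation (KL between softened teacher and student distributions, multiplied by the squared temperature); the function names are hypothetical and the repository's own KD distiller may differ in detail.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_total_loss(student_logits, teacher_logits, label,
                  temperature=1.0, ce_weight=0.5, kd_weight=0.5):
    """Weighted CE + KD loss for one sample (illustrative sketch).

    Assumption: KD term = T^2 * KL(teacher_T || student_T), as in the
    classic knowledge-distillation formulation.
    """
    # Cross-entropy against the ground-truth label at temperature 1.
    ce = -math.log(softmax(student_logits)[label])
    # KL divergence between temperature-softened distributions.
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(t * (math.log(t) - math.log(s))
             for t, s in zip(p_teacher, p_student) if t > 0)
    kd = kl * temperature ** 2
    return ce_weight * ce + kd_weight * kd, ce, kd


if __name__ == "__main__":
    print(kd_total_loss([2.0, 0.5, -1.0], [3.0, 0.0, -2.0], label=0))
```

When student and teacher logits coincide, the KD term vanishes and only the weighted cross-entropy remains.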
33 changes: 33 additions & 0 deletions configs/imagenet/r50_mv1/dot.yaml
@@ -0,0 +1,33 @@
EXPERIMENT:
  NAME: ""
  TAG: "kd,dot,res50,mobilenetv1"
  PROJECT: "dot_imagenet"
DATASET:
  TYPE: "imagenet"
  NUM_WORKERS: 32
  TEST:
    BATCH_SIZE: 128
DISTILLER:
  TYPE: "KD"
  TEACHER: "ResNet50"
  STUDENT: "MobileNetV1"
SOLVER:
  BATCH_SIZE: 512
  EPOCHS: 100
  LR: 0.2
  LR_DECAY_STAGES: [30, 60, 90]
  LR_DECAY_RATE: 0.1
  WEIGHT_DECAY: 0.0001
  MOMENTUM: 0.9
  TYPE: "SGD"
  TRAINER: "dot"
  DOT:
    DELTA: 0.09
KD:
  TEMPERATURE: 1
  LOSS:
    CE_WEIGHT: 0.5
    KD_WEIGHT: 0.5
LOG:
  TENSORBOARD_FREQ: 50
  SAVE_CHECKPOINT_FREQ: 10
23 changes: 23 additions & 0 deletions configs/tiny_imagenet/dot/r18_mv2.yaml
@@ -0,0 +1,23 @@
EXPERIMENT:
  NAME: ""
  TAG: "kd,dot,r18,mv2"
  PROJECT: "dot_tinyimagenet"
DATASET:
  TYPE: "tiny_imagenet"
  NUM_WORKERS: 16
DISTILLER:
  TYPE: "KD"
  TEACHER: "ResNet18"
  STUDENT: "MobileNetV2"
SOLVER:
  BATCH_SIZE: 256
  EPOCHS: 200
  LR: 0.2
  LR_DECAY_STAGES: [60, 120, 160]
  LR_DECAY_RATE: 0.1
  WEIGHT_DECAY: 0.0005
  MOMENTUM: 0.9
  TYPE: "SGD"
  TRAINER: "dot"
  DOT:
    DELTA: 0.075
23 changes: 23 additions & 0 deletions configs/tiny_imagenet/dot/r18_shuv2.yaml
@@ -0,0 +1,23 @@
EXPERIMENT:
  NAME: ""
  TAG: "kd,dot,r18,shuv2"
  PROJECT: "dot_tinyimagenet"
DATASET:
  TYPE: "tiny_imagenet"
  NUM_WORKERS: 16
DISTILLER:
  TYPE: "KD"
  TEACHER: "ResNet18"
  STUDENT: "ShuffleV2"
SOLVER:
  BATCH_SIZE: 256
  EPOCHS: 200
  LR: 0.2
  LR_DECAY_STAGES: [60, 120, 160]
  LR_DECAY_RATE: 0.1
  WEIGHT_DECAY: 0.0005
  MOMENTUM: 0.9
  TYPE: "SGD"
  TRAINER: "dot"
  DOT:
    DELTA: 0.075
16 changes: 16 additions & 0 deletions mdistiller/dataset/__init__.py
@@ -1,5 +1,6 @@
from .cifar100 import get_cifar100_dataloaders, get_cifar100_dataloaders_sample
from .imagenet import get_imagenet_dataloaders, get_imagenet_dataloaders_sample
from .tiny_imagenet import get_tinyimagenet_dataloader, get_tinyimagenet_dataloader_sample


def get_dataset(cfg):
@@ -34,6 +35,21 @@ def get_dataset(cfg):
                num_workers=cfg.DATASET.NUM_WORKERS,
            )
        num_classes = 1000
    elif cfg.DATASET.TYPE == "tiny_imagenet":
        if cfg.DISTILLER.TYPE in ("CRD", "CRDKD"):
            train_loader, val_loader, num_data = get_tinyimagenet_dataloader_sample(
                batch_size=cfg.SOLVER.BATCH_SIZE,
                val_batch_size=cfg.DATASET.TEST.BATCH_SIZE,
                num_workers=cfg.DATASET.NUM_WORKERS,
                k=cfg.CRD.NCE.K,
            )
        else:
            train_loader, val_loader, num_data = get_tinyimagenet_dataloader(
                batch_size=cfg.SOLVER.BATCH_SIZE,
                val_batch_size=cfg.DATASET.TEST.BATCH_SIZE,
                num_workers=cfg.DATASET.NUM_WORKERS,
            )
        num_classes = 200
    else:
        raise NotImplementedError(cfg.DATASET.TYPE)

122 changes: 122 additions & 0 deletions mdistiller/dataset/tiny_imagenet.py
@@ -0,0 +1,122 @@
import os
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
import numpy as np


data_folder = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "../../data/tiny-imagenet-200"
)


class ImageFolderInstance(datasets.ImageFolder):
    def __getitem__(self, index):
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        return img, target, index


class ImageFolderInstanceSample(ImageFolderInstance):
    """Folder dataset that returns (img, label, index, contrast_index)."""

    def __init__(self, folder, transform=None, target_transform=None,
                 is_sample=False, k=4096):
        super().__init__(folder, transform=transform)

        self.k = k
        self.is_sample = is_sample
        if self.is_sample:
            num_classes = 200
            num_samples = len(self.samples)
            label = np.zeros(num_samples, dtype=np.int32)
            for i in range(num_samples):
                img, target = self.samples[i]
                label[i] = target

            # Per-class lists of sample indices belonging to that class.
            self.cls_positive = [[] for i in range(num_classes)]
            for i in range(num_samples):
                self.cls_positive[label[i]].append(i)

            # Per-class lists of sample indices belonging to every other class.
            self.cls_negative = [[] for i in range(num_classes)]
            for i in range(num_classes):
                for j in range(num_classes):
                    if j == i:
                        continue
                    self.cls_negative[i].extend(self.cls_positive[j])

            self.cls_positive = [np.asarray(self.cls_positive[i], dtype=np.int32) for i in range(num_classes)]
            self.cls_negative = [np.asarray(self.cls_negative[i], dtype=np.int32) for i in range(num_classes)]
        print('dataset initialized!')

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target, index, sample_idx) if is_sample is set,
            otherwise (image, target, index); target is the class index.
        """
        img, target, index = super().__getitem__(index)

        if self.is_sample:
            # sample contrastive examples: the anchor index first, then k negatives
            pos_idx = index
            neg_idx = np.random.choice(self.cls_negative[target], self.k, replace=True)
            sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx))
            return img, target, index, sample_idx
        else:
            return img, target, index
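

The index bookkeeping in `ImageFolderInstanceSample` can be tried on a toy label list without any images. The sketch below is a dependency-free restatement of the same idea (per-class negative pools, then `k` negatives drawn with replacement after the anchor index); the helper names are hypothetical and it uses `random.choices` in place of `np.random.choice`.

```python
import random


def build_negative_pools(labels, num_classes):
    """Return (positives, negatives): index lists per class."""
    positives = [[] for _ in range(num_classes)]
    for idx, lab in enumerate(labels):
        positives[lab].append(idx)
    # For each class, collect the indices of all samples from other classes.
    negatives = [
        [idx for other in range(num_classes) if other != cls
         for idx in positives[other]]
        for cls in range(num_classes)
    ]
    return positives, negatives


def sample_contrast_indices(index, labels, negatives, k, rng=random):
    """Anchor index followed by k negatives drawn with replacement."""
    neg = rng.choices(negatives[labels[index]], k=k)
    return [index] + neg


if __name__ == "__main__":
    labels = [0, 0, 1, 2, 2]
    _, negatives = build_negative_pools(labels, num_classes=3)
    print(sample_contrast_indices(1, labels, negatives, k=4))
```

Note that the negative pools grow with the dataset size times the number of classes, which is why the original code stores them as `int32` arrays.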


def get_tinyimagenet_dataloader(batch_size, val_batch_size, num_workers):
    """Data Loader for tiny-imagenet"""
    train_transform = transforms.Compose([
        transforms.RandomRotation(20),
        transforms.RandomHorizontalFlip(0.5),
        transforms.ToTensor(),
        transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]),
    ])
    train_folder = os.path.join(data_folder, "train")
    test_folder = os.path.join(data_folder, "val")
    train_set = ImageFolderInstance(train_folder, transform=train_transform)
    num_data = len(train_set)
    test_set = datasets.ImageFolder(test_folder, transform=test_transform)
    train_loader = DataLoader(
        train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    test_loader = DataLoader(
        test_set, batch_size=val_batch_size, shuffle=False, num_workers=1
    )
    return train_loader, test_loader, num_data


def get_tinyimagenet_dataloader_sample(batch_size, val_batch_size, num_workers, k):
    """Data loader for tiny-imagenet that also yields contrastive sample indices (k negatives per item)"""
    train_transform = transforms.Compose([
        transforms.RandomRotation(20),
        transforms.RandomHorizontalFlip(0.5),
        transforms.ToTensor(),
        transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]),
    ])
    train_folder = os.path.join(data_folder, "train")
    test_folder = os.path.join(data_folder, "val")
    train_set = ImageFolderInstanceSample(train_folder, transform=train_transform, is_sample=True, k=k)
    num_data = len(train_set)
    test_set = datasets.ImageFolder(test_folder, transform=test_transform)
    train_loader = DataLoader(
        train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    test_loader = DataLoader(
        test_set, batch_size=val_batch_size, shuffle=False, num_workers=1
    )
    return train_loader, test_loader, num_data
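
Both loaders read the validation split with `datasets.ImageFolder`, which expects one subdirectory per class. The stock Tiny-ImageNet archive instead ships `val/images/` plus a tab-separated `val/val_annotations.txt`, so the validation folder may need regrouping first. Whether the repository expects this preprocessing is not stated in the commit; the helper below, `regroup_val_by_class`, is a hypothetical sketch under that assumption.

```python
import os
import shutil


def regroup_val_by_class(val_dir):
    """Move val/images/<file> to val/<wnid>/<file> (illustrative sketch).

    Assumption: val_annotations.txt has tab-separated rows whose first
    two columns are the image file name and the WordNet class id.
    Returns the number of files moved.
    """
    ann_path = os.path.join(val_dir, "val_annotations.txt")
    img_dir = os.path.join(val_dir, "images")
    moved = 0
    with open(ann_path) as f:
        for line in f:
            parts = line.rstrip("\n").split("\t")
            if len(parts) < 2:
                continue  # skip blank or malformed rows
            fname, wnid = parts[0], parts[1]
            src = os.path.join(img_dir, fname)
            if not os.path.isfile(src):
                continue  # already moved, or missing on disk
            dst_dir = os.path.join(val_dir, wnid)
            os.makedirs(dst_dir, exist_ok=True)
            shutil.move(src, os.path.join(dst_dir, fname))
            moved += 1
    return moved
```

Calling it a second time moves nothing, because files that have already been relocated are skipped.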
5 changes: 3 additions & 2 deletions mdistiller/engine/__init__.py
@@ -1,6 +1,7 @@
-from .trainer import BaseTrainer, CRDTrainer
-
+from .trainer import BaseTrainer, CRDTrainer, DOT, CRDDOT
 trainer_dict = {
     "base": BaseTrainer,
     "crd": CRDTrainer,
+    "dot": DOT,
+    "crd_dot": CRDDOT,
 }
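
The `TRAINER: "dot"` entries in the new configs presumably select a class through this dictionary. The sketch below shows that lookup pattern with stand-in classes; the real classes live in `mdistiller/engine/trainer.py`, and the helper `select_trainer` is hypothetical, since the training script that consumes `trainer_dict` is not part of this section of the diff.

```python
# Stand-in classes so the sketch runs on its own; they are NOT the real trainers.
class BaseTrainer:
    pass


class CRDTrainer(BaseTrainer):
    pass


class DOT(BaseTrainer):
    pass


class CRDDOT(BaseTrainer):
    pass


trainer_dict = {
    "base": BaseTrainer,
    "crd": CRDTrainer,
    "dot": DOT,
    "crd_dot": CRDDOT,
}


def select_trainer(name):
    """Map a config string such as "dot" to a trainer class."""
    try:
        return trainer_dict[name]
    except KeyError:
        known = ", ".join(sorted(trainer_dict))
        raise ValueError(f"unknown trainer {name!r}; expected one of: {known}")


if __name__ == "__main__":
    print(select_trainer("dot").__name__)
```

Raising a `ValueError` that lists the registered names makes a typo in a YAML config easier to diagnose than a bare `KeyError`.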