train.py
import hydra
from omegaconf import DictConfig
from tqdm import tqdm
import torch
from torch import nn, optim
import torch.backends.cudnn as cudnn
from data_generator import get_data_loader
from data_preparation.verify_data import verify_data
from utils.general_utils import create_directory, join_paths, set_gpus, get_gpus_count
from models.model import prepare_model
from losses.loss import dice_coef
from losses.unet_loss import unet3p_hybrid_loss


def create_training_folders(cfg: DictConfig):
    """
    Create directories to store model checkpoints and training logs.
    """
    create_directory(
        join_paths(
            cfg.WORK_DIR,
            cfg.CALLBACKS.MODEL_CHECKPOINT.PATH
        )
    )
    create_directory(
        join_paths(
            cfg.WORK_DIR,
            cfg.CALLBACKS.LOGGING.PATH
        )
    )


def train(cfg: DictConfig):
    """
    Training method
    """
    print("Verifying data ...")
    verify_data(cfg)

    if cfg.USE_MULTI_GPUS.VALUE:
        # change number of visible gpus for training
        set_gpus(cfg.USE_MULTI_GPUS.GPU_IDS)
        # change batch size according to available gpus
        cfg.HYPER_PARAMETERS.BATCH_SIZE = \
            cfg.HYPER_PARAMETERS.BATCH_SIZE * get_gpus_count()

    # create folders to store training checkpoints and logs
    create_training_folders(cfg)

    # data generators
    train_loader = get_data_loader(cfg, mode="TRAIN")
    val_loader = get_data_loader(cfg, mode="VAL")

    # set device with gpu id
    gpu_id = cfg.GPU_ID
    device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
    cudnn.benchmark = True

    # create model
    model = prepare_model(cfg, training=True).to(device)
    if cfg.USE_MULTI_GPUS.VALUE:
        model = nn.DataParallel(model)

    # optimizer and loss
    optimizer = optim.Adam(model.parameters(), lr=cfg.HYPER_PARAMETERS.LEARNING_RATE)
    criterion = unet3p_hybrid_loss

    # paths for the model checkpoint, CSV log and text log files
    checkpoint_path = join_paths(
        cfg.WORK_DIR,
        cfg.CALLBACKS.MODEL_CHECKPOINT.PATH,
        f"{cfg.MODEL.WEIGHTS_FILE_NAME}.pt"
    )
    print("Weights path\n" + checkpoint_path)

    csv_log_path = join_paths(
        cfg.WORK_DIR,
        cfg.CALLBACKS.CSV_LOGGER.PATH,
        f"training_logs_{cfg.MODEL.TYPE}.csv"
    )
    print("Logs path\n" + csv_log_path)

    txt_log_path = join_paths(
        cfg.WORK_DIR,
        cfg.CALLBACKS.LOGGING.PATH,
        f"training_logs_{cfg.MODEL.TYPE}.txt"
    )
    print("Text Logs path\n" + txt_log_path)

    # track the best validation dice score and the early-stopping patience counter
    best_val_score = 0.0
    patience_count = 0

    for epoch in range(cfg.HYPER_PARAMETERS.EPOCHS):
        # training phase
        model.train()
        train_loss = 0.0
        train_dice = 0.0
        for batch_images, batch_masks in tqdm(train_loader):
            batch_images, batch_masks = batch_images.to(device), batch_masks.to(device)
            optimizer.zero_grad()
            outputs = model(batch_images)
            # DataParallel does not forward custom attributes, so read
            # deep_supervision from the wrapped module when multi-GPU is used
            base_model = model.module if isinstance(model, nn.DataParallel) else model
            if base_model.deep_supervision:
                # deep supervision yields one prediction per decoder level,
                # so replicate the masks to match
                batch_masks = batch_masks.repeat(5, 1, 1, 1)
            loss = criterion(batch_masks, outputs)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * batch_images.size(0)
            train_dice += dice_coef(outputs, batch_masks).item() * batch_images.size(0)
        train_loss /= len(train_loader.dataset)
        train_dice /= len(train_loader.dataset)

        # validation phase
        model.eval()
        val_loss = 0.0
        val_dice = 0.0
        with torch.no_grad():
            for batch_images, batch_masks in val_loader:
                batch_images, batch_masks = batch_images.to(device), batch_masks.to(device)
                outputs = model(batch_images)
                # keep the (targets, predictions) argument order consistent with training
                loss = criterion(batch_masks, outputs)
                val_loss += loss.item() * batch_images.size(0)
                val_dice += dice_coef(outputs, batch_masks).item() * batch_images.size(0)
        val_loss /= len(val_loader.dataset)
        val_dice /= len(val_loader.dataset)

        print(f"Epoch {epoch+1}/{cfg.HYPER_PARAMETERS.EPOCHS}, "
              f"Train Loss: {train_loss:.4f}, Train Dice: {train_dice:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Dice: {val_dice:.4f}")
        # append to the text log so earlier epochs are not overwritten
        with open(txt_log_path, 'a') as log_file:
            log_file.write(f"Epoch {epoch+1}/{cfg.HYPER_PARAMETERS.EPOCHS}, "
                           f"Train Loss: {train_loss:.4f}, Train Dice: {train_dice:.4f}, "
                           f"Val Loss: {val_loss:.4f}, Val Dice: {val_dice:.4f}\n")
        with open(csv_log_path, 'a') as f:
            f.write(f"{epoch+1},{train_loss},{train_dice},{val_loss},{val_dice}\n")

        # ModelCheckpoint: save weights whenever the validation dice score improves
        if val_dice > best_val_score:
            best_val_score = val_dice
            torch.save(model.state_dict(), checkpoint_path)
            print("Saved best model")

        # EarlyStopping: reset patience while the score stays within DELTA of the best,
        # otherwise count towards the patience limit
        if val_dice > best_val_score - cfg.CALLBACKS.EARLY_STOPPING.DELTA:
            patience_count = 0
        else:
            patience_count += 1
        if patience_count >= cfg.CALLBACKS.EARLY_STOPPING.PATIENCE:
            print("Early stopping")
            break


@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig):
    """
    Read config file and pass to train method for training
    """
    train(cfg)


if __name__ == "__main__":
    main()
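

# Example invocations, assuming configs/config.yaml defines the keys referenced
# above (HYPER_PARAMETERS, GPU_ID, USE_MULTI_GPUS, CALLBACKS, MODEL); Hydra lets
# any of them be overridden from the command line, for instance:
#
#   python train.py HYPER_PARAMETERS.EPOCHS=50 HYPER_PARAMETERS.BATCH_SIZE=4 GPU_ID=0
#   python train.py USE_MULTI_GPUS.VALUE=True "USE_MULTI_GPUS.GPU_IDS=[0,1]"
#
# Adjust the overrides to the actual key names and defaults in configs/config.yaml.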