-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_train.py
75 lines (58 loc) · 3.09 KB
/
main_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import argparse


def _str2bool(value):
    """Parse a command-line boolean value.

    argparse's ``type=bool`` is a well-known trap: ``bool('False')`` is
    ``True`` because any non-empty string is truthy, so ``--binary False``
    silently enabled binary mode. This parser accepts the usual spellings
    of true/false and rejects anything else, while keeping the original
    ``--binary <value>`` CLI syntax backward-compatible.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Expected a boolean value, got {value!r}')


# Command-line interface: model choice, dataset locations and image size are
# required; batch size and the binary-mask flag have defaults.
parser = argparse.ArgumentParser(description='Train model.')
parser.add_argument('--model-name', help='Name of model to train.', choices=['cenet', 'deeplabv3plus', 'unetpp'], required=True)
parser.add_argument('--name-csv-train', help='Name of the CSV file with training dataset information.', required=True)
parser.add_argument('--data-dir', help='Path to the folder with the CSV files and image subfolders.', required=True)
parser.add_argument('--path-save', help='Path to the folder where model will be saved.', required=True)
parser.add_argument('--img-size', type=int, help='Size to which the images should be reshaped (one number, i.e. 256 or 512).', required=True)
parser.add_argument('--batch-size', type=int, help='Batch size for the model during training.', default=4)
parser.add_argument('--binary', type=_str2bool, help='Whether the segmentation masks are binary (True) or multi-class (False).', default=False)
args = parser.parse_args()
import os
import os.path as osp
import matplotlib.pyplot as plt
import pandas as pd
from models.cenet import CENet
from models.deeplabv3plus import DeepLabV3Plus
from models.unetpp import UnetPlusPlus
from utils.data_utils import *
# Resolve dataset locations from the CLI arguments.
train_path = osp.join(args.data_dir, args.name_csv_train)
img_size = (args.img_size, args.img_size)
dataset_dir = args.data_dir

# Image/mask locations come from the training CSV: one row per sample with
# the file name plus its image and segmentation-mask directories.
df_train = pd.read_csv(train_path)[['imageID', 'imageDIR', 'segDIR']].values.tolist()

# Paths in the CSV are relative to the parent of the data directory.
_base_dir = osp.split(dataset_dir)[0]
train_paths = [
    (osp.join(_base_dir, img_dir, image_id).replace('\\', '/'),
     osp.join(_base_dir, seg_dir, image_id).replace('\\', '/'))
    for image_id, img_dir, seg_dir in df_train
]
# Instantiate the selected architecture via a dispatch table.
# The (H, W, 3) input shape and channel count (2 for binary masks, 3 for
# multi-class) only matter for the unet-style models; SOTA models have their
# own size/n_channels and this argument is disregarded.
model = {
    'cenet': CENet,
    'deeplabv3plus': DeepLabV3Plus,
    'unetpp': UnetPlusPlus
}[args.model_name]((img_size[0], img_size[1], 3), 2 if args.binary else 3)

# Models implemented in PyTorch (saved as .pth); everything else saves as .h5.
torch_models = ['cenet']
# Models expecting polar-transformed inputs. NOTE(review): 'mnet' is not among
# the --model-name choices, so this list is currently inert — confirm whether
# mnet support was removed or is pending.
polar_models = ['mnet']

# Hold out 10% of the training rows for validation inside the generators.
val_size = 0.1
train_gen, val_gen, _ = get_gens(img_size, train_paths, [], args.batch_size, val_size=val_size, binary=args.binary, polar=(args.model_name in polar_models), channelsFirst=(args.model_name in torch_models))
train_len = int(len(train_paths) * (1 - val_size))
val_len = len(train_paths) - train_len

# Models needing extra config. NOTE(review): 'attnet' is likewise not a valid
# --model-name choice, so this branch is unreachable as written — confirm.
if args.model_name == 'attnet':
    model.set_config_params(args.batch_size, train_len, val_len)

# Train; `history` is expected to expose per-epoch 'loss' / 'val_loss' lists
# (consumed by the plotting/CSV code below).
history = model.train(train_gen, val_gen, train_len // args.batch_size, val_len // args.batch_size)

# Ensure the output directory exists. exist_ok=True avoids the check-then-create
# race of the original `if not isdir: makedirs` pair.
sp = args.path_save
os.makedirs(sp, exist_ok=True)
model.save(osp.join(sp, f'{args.model_name}_model' + ['.h5', '.pth'][args.model_name in torch_models]))
# Plot training/validation loss curves and persist them next to the saved
# model, plus a CSV of per-epoch losses for later analysis.
plt.plot(history['loss'])
plt.plot(history['val_loss'])
plt.title('Model Loss During Training')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')
# plt.show()
plt.savefig(osp.join(sp, f'{args.model_name}_loss.png'))

# Write per-epoch losses. The `with` block guarantees the file is closed even
# if a write raises (the original open()/close() pair leaked on error), and
# the unused `n = write(...)` binding is dropped. zip pairs the two series
# directly instead of indexing by range(len(...)).
with open(osp.join(sp, 'losses.csv'), 'w') as out_write:
    out_write.write('epoch,loss,val_loss\n')
    for epoch, (train_loss, val_loss) in enumerate(zip(history['loss'], history['val_loss'])):
        out_write.write(f"{epoch},{train_loss},{val_loss}\n")