train_synapse.py
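"""Training entry point for EMCAD on the Synapse multi-organ segmentation dataset.

The script parses the network and training hyper-parameters below, seeds all random
number generators, builds an EMCADNet with the chosen encoder/decoder settings, and
hands the model to trainer_synapse together with a snapshot directory whose name is
derived from the configuration.
"""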
import argparse
import logging
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from lib.networks import EMCADNet
from trainer import trainer_synapse

parser = argparse.ArgumentParser()
parser.add_argument('--root_path', type=str,
                    default='./data/synapse/train_npz', help='root dir for data')
parser.add_argument('--volume_path', type=str,
                    default='./data/synapse/test_vol_h5', help='root dir for validation volume data')
parser.add_argument('--dataset', type=str,
                    default='Synapse', help='experiment name')
parser.add_argument('--list_dir', type=str,
                    default='./lists/lists_Synapse', help='list dir')
parser.add_argument('--num_classes', type=int,
                    default=9, help='number of output channels of the network')

# network related parameters
parser.add_argument('--encoder', type=str,
                    default='pvt_v2_b2', help='name of encoder: pvt_v2_b2, pvt_v2_b0, resnet18, resnet34 ...')
parser.add_argument('--expansion_factor', type=int,
                    default=2, help='expansion factor in MSCB block')
parser.add_argument('--kernel_sizes', type=int, nargs='+',
                    default=[1, 3, 5], help='multi-scale kernel sizes in MSDC block')
parser.add_argument('--lgag_ks', type=int,
                    default=3, help='kernel size in LGAG')
parser.add_argument('--activation_mscb', type=str,
                    default='relu6', help='activation used in MSCB: relu6 or relu')
parser.add_argument('--no_dw_parallel', action='store_true',
                    default=False, help='use this flag to disable depth-wise parallel convolutions')
parser.add_argument('--concatenation', action='store_true',
                    default=False, help='use this flag to concatenate feature maps in MSDC block')
parser.add_argument('--no_pretrain', action='store_true',
                    default=False, help='use this flag to turn off loading pretrained encoder weights')
parser.add_argument('--supervision', type=str,
                    default='mutation', help='loss supervision: mutation, deep_supervision or last_layer')

# training related parameters
parser.add_argument('--max_iterations', type=int,
                    default=50000, help='maximum iteration number to train')
parser.add_argument('--max_epochs', type=int,
                    default=300, help='maximum epoch number to train')
parser.add_argument('--batch_size', type=int,
                    default=6, help='batch size per GPU')
parser.add_argument('--base_lr', type=float, default=0.0001,
                    help='segmentation network learning rate')
parser.add_argument('--img_size', type=int,
                    default=224, help='input patch size of network input')
parser.add_argument('--n_gpu', type=int, default=1, help='total number of GPUs')
parser.add_argument('--deterministic', type=int, default=1,
                    help='whether to use deterministic training')
parser.add_argument('--seed', type=int,
                    default=2222, help='random seed')
args = parser.parse_args()
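
# Example invocation (illustrative; every flag falls back to the parser defaults above):
#   python train_synapse.py --root_path ./data/synapse/train_npz \
#       --volume_path ./data/synapse/test_vol_h5 --encoder pvt_v2_b2 \
#       --batch_size 6 --max_epochs 300 --base_lr 0.0001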

if __name__ == "__main__":
    # cuDNN benchmarking speeds up training but breaks exact reproducibility,
    # so it is only enabled when deterministic training is turned off
    if not args.deterministic:
        cudnn.benchmark = True
        cudnn.deterministic = False
    else:
        cudnn.benchmark = False
        cudnn.deterministic = True

    # seed all random number generators for reproducibility
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    dataset_name = args.dataset
    dataset_config = {
        'Synapse': {
            'root_path': args.root_path,
            'volume_path': args.volume_path,
            'list_dir': args.list_dir,
            'num_classes': args.num_classes,
            'z_spacing': 1,
        },
    }
    args.num_classes = dataset_config[dataset_name]['num_classes']
    args.root_path = dataset_config[dataset_name]['root_path']
    args.volume_path = dataset_config[dataset_name]['volume_path']
    args.z_spacing = dataset_config[dataset_name]['z_spacing']
    args.list_dir = dataset_config[dataset_name]['list_dir']

    # decoder feature aggregation and depth-wise convolution modes derived from the flags
    if args.concatenation:
        aggregation = 'concat'
    else:
        aggregation = 'add'

    if args.no_dw_parallel:
        dw_mode = 'series'
    else:
        dw_mode = 'parallel'

    # build the experiment name and snapshot directory from the hyper-parameters
    run = 1
    args.exp = args.encoder + '_EMCAD_kernel_sizes_' + str(args.kernel_sizes) + '_dw_' + dw_mode + '_' + aggregation + '_lgag_ks_' + str(args.lgag_ks) + '_ef' + str(args.expansion_factor) + '_act_mscb_' + args.activation_mscb + '_loss_' + args.supervision + '_output_final_layer_Run' + str(run) + '_' + dataset_name + str(args.img_size)
    snapshot_path = "model_pth/{}/{}".format(args.exp, args.encoder + '_EMCAD_kernel_sizes_' + str(args.kernel_sizes) + '_dw_' + dw_mode + '_' + aggregation + '_lgag_ks_' + str(args.lgag_ks) + '_ef' + str(args.expansion_factor) + '_act_mscb_' + args.activation_mscb + '_loss_' + args.supervision + '_output_final_layer_Run' + str(run))
    snapshot_path = snapshot_path.replace('[', '').replace(']', '').replace(', ', '_')

    # append the remaining settings to the snapshot directory name
    snapshot_path = snapshot_path + '_pretrain' if not args.no_pretrain else snapshot_path
    snapshot_path = snapshot_path + '_' + str(args.max_iterations)[0:2] + 'k' if args.max_iterations != 50000 else snapshot_path
    snapshot_path = snapshot_path + '_epo' + str(args.max_epochs) if args.max_epochs != 300 else snapshot_path
    snapshot_path = snapshot_path + '_bs' + str(args.batch_size)
    snapshot_path = snapshot_path + '_lr' + str(args.base_lr) if args.base_lr != 0.0001 else snapshot_path
    snapshot_path = snapshot_path + '_' + str(args.img_size)
    snapshot_path = snapshot_path + '_s' + str(args.seed) if args.seed != 1234 else snapshot_path

    if not os.path.exists(snapshot_path):
        os.makedirs(snapshot_path)

    # instantiate the EMCAD model and hand it to the Synapse trainer
    model = EMCADNet(num_classes=args.num_classes, kernel_sizes=args.kernel_sizes, expansion_factor=args.expansion_factor,
                     dw_parallel=not args.no_dw_parallel, add=not args.concatenation, lgag_ks=args.lgag_ks,
                     activation=args.activation_mscb, encoder=args.encoder, pretrain=not args.no_pretrain)
    model.cuda()
    print('Model successfully created.')

    trainer = {'Synapse': trainer_synapse}
    trainer[dataset_name](args, model, snapshot_path)