forked from ikhlestov/vision_networks
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_dense_net.py
147 lines (131 loc) · 5.38 KB
/
run_dense_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import argparse
from models.dense_net import DenseNet
from data_providers.utils import get_data_provider_by_name
# Training hyper-parameters shared by every CIFAR variant (C10, C10+, C100, C100+).
train_params_cifar = {
    'batch_size': 64,
    'n_epochs': 300,
    'initial_learning_rate': 0.1,
    'reduce_lr_epoch_1': 150,  # epochs * 0.5
    'reduce_lr_epoch_2': 225,  # epochs * 0.75
    'validation_set': True,
    'validation_split': None,  # None or float
    'shuffle': 'every_epoch',  # None, once_prior_train, every_epoch
    # Options: None, divide_256, divide_255, by_chanels.
    # NOTE(review): 'by_chanels' is presumably the exact string the data
    # provider compares against -- do not correct the spelling here without
    # checking data_providers first.
    'normalization': 'by_chanels',
}

# Training hyper-parameters for the SVHN dataset.
train_params_svhn = {
    'batch_size': 64,
    'n_epochs': 40,
    'initial_learning_rate': 0.1,
    'reduce_lr_epoch_1': 20,
    'reduce_lr_epoch_2': 30,
    'validation_set': True,
    'validation_split': None,  # you may set it to 6000 as in the paper
    'shuffle': True,  # shuffle dataset every epoch or not
    'normalization': 'divide_255',
}

# Dataset names that share ``train_params_cifar``.
_CIFAR_DATASETS = frozenset(['C10', 'C10+', 'C100', 'C100+'])


def get_train_params_by_name(name):
    """Return the training hyper-parameters for the dataset ``name``.

    Args:
        name: dataset identifier, one of 'C10', 'C10+', 'C100', 'C100+'
            or 'SVHN'.

    Returns:
        A new dict copied from the matching module-level defaults. Callers
        may modify it without affecting later calls.

    Raises:
        ValueError: if ``name`` is not a known dataset.
    """
    if name in _CIFAR_DATASETS:
        params = train_params_cifar
    elif name == 'SVHN':
        params = train_params_svhn
    else:
        # Previously this fell through and returned None, which only failed
        # later and far from the cause.
        raise ValueError("Unknown dataset name: %r" % (name,))
    # Return a copy so the module-level defaults cannot be mutated by callers.
    return dict(params)
def _build_arg_parser():
    """Return the command-line parser for the DenseNet train/test script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--train', action='store_true',
        help='Train the model')
    parser.add_argument(
        '--test', action='store_true',
        help='Test model for required dataset if pretrained model exists. '
             'If provided together with `--train` flag testing will be '
             'performed right after training.')
    parser.add_argument(
        '--model_type', '-m', type=str, choices=['DenseNet', 'DenseNet-BC'],
        default='DenseNet',
        help='What type of model to use')
    parser.add_argument(
        '--growth_rate', '-k', type=int, choices=[12, 24, 40],
        default=12,
        help='Growth rate for every layer, '
             'choices were restricted to those used in the paper')
    parser.add_argument(
        '--depth', '-d', type=int, choices=[40, 100, 190, 250],
        default=40,
        help='Depth of whole network, restricted to paper choices')
    parser.add_argument(
        '--dataset', '-ds', type=str,
        choices=['C10', 'C10+', 'C100', 'C100+', 'SVHN'],
        default='C10',
        help='What dataset should be used')
    parser.add_argument(
        '--total_blocks', '-tb', type=int, default=3, metavar='',
        help='Total blocks of layers stack (default: %(default)s)')
    parser.add_argument(
        '--keep_prob', '-kp', type=float, metavar='',
        help='Keep probability for dropout, in (0, 1] '
             '(default: 0.8 for C10/C100/SVHN, 1.0 otherwise)')
    parser.add_argument(
        '--weight_decay', '-wd', type=float, default=1e-4, metavar='',
        help='Weight decay for optimizer (default: %(default)s)')
    parser.add_argument(
        '--nesterov_momentum', '-nm', type=float, default=0.9, metavar='',
        help='Nesterov momentum (default: %(default)s)')
    parser.add_argument(
        '--reduction', '-red', type=float, default=0.5, metavar='',
        help='reduction Theta at transition layer for DenseNets-BC models '
             '(default: %(default)s; forced to 1.0 for plain DenseNet)')
    parser.add_argument(
        '--logs', dest='should_save_logs', action='store_true',
        help='Write tensorflow logs')
    parser.add_argument(
        '--no-logs', dest='should_save_logs', action='store_false',
        help='Do not write tensorflow logs')
    parser.set_defaults(should_save_logs=True)
    parser.add_argument(
        '--saves', dest='should_save_model', action='store_true',
        help='Save model during training')
    parser.add_argument(
        '--no-saves', dest='should_save_model', action='store_false',
        help='Do not save model during training')
    parser.set_defaults(should_save_model=True)
    parser.add_argument(
        '--renew-logs', dest='renew_logs', action='store_true',
        help='Erase previous logs for model if exists.')
    parser.add_argument(
        '--not-renew-logs', dest='renew_logs', action='store_false',
        help='Do not erase previous logs for model if exists.')
    parser.set_defaults(renew_logs=True)
    return parser


def _apply_model_defaults(args):
    """Fill dataset- and architecture-dependent settings on ``args`` in place."""
    if args.keep_prob is None:
        # Dropout is enabled only for the datasets listed here. NOTE(review):
        # presumably the '+' variants are the augmented datasets -- confirm
        # in data_providers.
        if args.dataset in ['C10', 'C100', 'SVHN']:
            args.keep_prob = 0.8
        else:
            args.keep_prob = 1.0
    if args.model_type == 'DenseNet':
        # Plain DenseNet has no bottleneck/compression, so --reduction is ignored.
        args.bc_mode = False
        args.reduction = 1.0
    elif args.model_type == 'DenseNet-BC':
        args.bc_mode = True


def main(argv=None):
    """Parse command-line arguments, then train and/or test a DenseNet model.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Raises:
        SystemExit: on invalid arguments (status 2, via argparse), or when
            neither ``--train`` nor ``--test`` is given.
    """
    import sys

    parser = _build_arg_parser()
    args = parser.parse_args(argv)
    # Check `is None` rather than falsiness so an explicit value is never
    # silently replaced by the default. Reject 0 (and out-of-range values)
    # instead, since a keep probability of 0 would drop every unit.
    if args.keep_prob is not None and not 0.0 < args.keep_prob <= 1.0:
        parser.error('--keep_prob must be in the range (0, 1]')
    _apply_model_defaults(args)
    # vars() returns the namespace's own __dict__, so every parsed option
    # (including train/test/dataset) is forwarded to DenseNet as a kwarg.
    model_params = vars(args)
    if not args.train and not args.test:
        # Use sys.exit rather than the site-only exit() builtin. Passing the
        # message prints it to stderr and exits with status 1, signalling a
        # usage error.
        sys.exit("You should train or test your network. Please check params.")

    # some default params dataset/architecture related
    train_params = get_train_params_by_name(args.dataset)
    print("Params:")
    for k, v in model_params.items():
        print("\t%s: %s" % (k, v))
    print("Train params:")
    for k, v in train_params.items():
        print("\t%s: %s" % (k, v))

    print("Prepare training data...")
    data_provider = get_data_provider_by_name(args.dataset, train_params)
    print("Initialize the model..")
    model = DenseNet(data_provider=data_provider, **model_params)
    if args.train:
        print("Data provider train images: ", data_provider.train.num_examples)
        model.train_all_epochs(train_params)
    if args.test:
        # Without a training run in this process, restore saved weights first.
        if not args.train:
            model.load_model()
        print("Data provider test images: ", data_provider.test.num_examples)
        print("Testing...")
        loss, accuracy = model.test(data_provider.test, batch_size=200)
        print("mean cross_entropy: %f, mean accuracy: %f" % (loss, accuracy))


if __name__ == '__main__':
    main()