-
Notifications
You must be signed in to change notification settings - Fork 216
/
train.py
executable file
·39 lines (30 loc) · 1.11 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#!/usr/bin/env python3
import os
import argparse
from networks.lenet import LeNet
from networks.pure_cnn import PureCnn
from networks.network_in_network import NetworkInNetwork
from networks.resnet import ResNet
from networks.densenet import DenseNet
from networks.wide_resnet import WideResNet
from networks.capsnet import CapsNet
def build_parser(model_names):
    """Build the command-line parser for this training script.

    Args:
        model_names: Iterable of model names accepted by ``--model``
            (a dict works; its keys are used, in insertion order).

    Returns:
        An ``argparse.ArgumentParser`` with a required ``--model`` choice and
        optional integer ``--epochs`` / ``--batch_size`` options.
    """
    parser = argparse.ArgumentParser(description='Train models on Cifar10')
    parser.add_argument('--model', choices=list(model_names), required=True,
                        help='Specify a model by name to train.')
    # Both options default to None, which means "not given on the command
    # line": collect_model_kwargs() drops them so they are never forwarded
    # to the model constructor.
    parser.add_argument('--epochs', default=None, type=int,
                        help="Number of training epochs (omit to use the model's default).")
    parser.add_argument('--batch_size', default=None, type=int,
                        help="Training batch size (omit to use the model's default).")
    return parser


def collect_model_kwargs(args):
    """Return the parsed options that should be forwarded to the model.

    Args:
        args: The ``argparse.Namespace`` produced by ``build_parser``.

    Returns:
        A dict of option name -> value, excluding the ``model`` selector and
        every option whose value is ``None`` (i.e. not supplied). Falsy but
        explicit values such as ``0`` are kept.
    """
    return {
        key: value
        for key, value in vars(args).items()
        if key != 'model' and value is not None
    }


def main(argv=None):
    """Parse the command line, construct the selected model and train it.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``
            (standard ``argparse`` behaviour).
    """
    # Built inside main() so the project model classes are only looked up
    # when the script actually runs, not when this module is imported.
    models = {
        'lenet': LeNet,
        'pure_cnn': PureCnn,
        'net_in_net': NetworkInNetwork,
        'resnet': ResNet,
        'densenet': DenseNet,
        'wide_resnet': WideResNet,
        'capsnet': CapsNet,
    }
    args = build_parser(models).parse_args(argv)
    model_cls = models[args.model]
    model = model_cls(**collect_model_kwargs(args), load_weights=False)
    model.train()


if __name__ == '__main__':
    main()