import torch

from utils.utils import Logger
from utils.utils import save_checkpoint
from utils.utils import save_linear_checkpoint

from common.train import *  # provides P, model, criterion, optimizer, scheduler_warmup, the data loaders, etc.
from evals import test_classifier
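
# Pick the per-epoch training routine according to the training mode.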
if 'sup' in P.mode:
    from training.sup import setup
else:
    from training.unsup import setup
train, fname = setup(P.mode, P)

logger = Logger(fname, ask=not resume, local_rank=P.local_rank)
logger.log(P)
logger.log(model)
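
# Extract the linear classifier head (unwrapping the DataParallel module when
# multi-GPU training is on) and give it its own Adam optimizer.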
if P.multi_gpu:
    linear = model.module.linear
else:
    linear = model.linear
linear_optim = torch.optim.Adam(linear.parameters(), lr=1e-3,
                                betas=(0.9, 0.999), weight_decay=P.weight_decay)

# Run experiments
for epoch in range(start_epoch, P.epochs + 1):
    logger.log_dirname(f"Epoch {epoch}")
    model.train()

    if P.multi_gpu:
        train_sampler.set_epoch(epoch)  # reshuffle the distributed shards each epoch

    kwargs = {}
    kwargs['linear'] = linear
    kwargs['linear_optim'] = linear_optim
    kwargs['simclr_aug'] = simclr_aug

    train(P, epoch, model, criterion, optimizer, scheduler_warmup,
          train_loader, logger=logger, **kwargs)

    model.eval()
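
    # Periodically save the model and optimizer states (main process only).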
    if epoch % P.save_step == 0 and P.local_rank == 0:
        if P.multi_gpu:
            save_states = model.module.state_dict()
        else:
            save_states = model.state_dict()
        save_checkpoint(epoch, save_states, optimizer.state_dict(), logger.logdir)
        save_linear_checkpoint(linear_optim.state_dict(), logger.logdir)
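
    # Periodically evaluate the classifier and track the best test error
    # (only meaningful for supervised modes).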
    if epoch % P.error_step == 0 and ('sup' in P.mode):
        error = test_classifier(P, model, test_loader, epoch, logger=logger)

        is_best = (best > error)
        if is_best:
            best = error

        logger.scalar_summary('eval/best_error', best, epoch)
        logger.log('[Epoch %3d] [Test %5.2f] [Best %5.2f]' % (epoch, error, best))
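
# Rough sketch of a typical invocation (flag names come from the argparse setup
# in common.train and are assumed here, not verified):
#   python train.py --mode simclr --dataset cifar10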