-
Notifications
You must be signed in to change notification settings - Fork 2
/
_mnist_ddp.py
52 lines (46 loc) · 1.41 KB
/
_mnist_ddp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
"""Train MNIST_DDP on MNIST with the custom DDPNOPT optimizer.

Runs 100 epochs; after each epoch evaluates on the validation split and
prints the macro-averaged F1 score. Requires a CUDA device.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import f1_score
from tqdm.auto import tqdm

from models import MNIST_DDP
from ddp import DDPNOPT
from dataloaders import MNIST_Data

device = 'cuda'
net = MNIST_DDP().cuda()
# lr drives the base parameter update, lrddp the DDP-specific one —
# presumably; semantics live in DDPNOPT, confirm against its definition.
opt = DDPNOPT(net, lr=1e-4, lrddp=1e-4)
ce = nn.CrossEntropyLoss()

datatrain = MNIST_Data(True)    # True  -> training split (assumed; see MNIST_Data)
train_loader = DataLoader(datatrain, batch_size=32, shuffle=True)
dataval = MNIST_Data(False)     # False -> validation split
val_loader = DataLoader(dataval, batch_size=32)

for epoch in range(100):
    print('Epoch', epoch + 1)
    torch.cuda.empty_cache()

    # --- training pass ---
    net.train()
    for x, y in tqdm(train_loader):
        opt.zero_grad()
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        loss = ce(net(x), y)
        # NOTE(review): retain_graph=True kept because DDPNOPT.step() may
        # re-traverse the graph; if it does not, this only wastes memory
        # and should be dropped — confirm against the optimizer.
        loss.backward(retain_graph=True)
        opt.step()

    # --- validation pass: accumulate labels/predictions, report macro F1 ---
    net.eval()
    gts, preds = [], []
    with torch.no_grad():
        for x, y in val_loader:
            x = x.to(device, non_blocking=True)
            pred = net(x).argmax(dim=-1)
            gts += list(y.numpy())
            preds += list(pred.cpu().numpy())
    print(f1_score(gts, preds, average='macro'))