forked from zijundeng/pytorch-semantic-segmentation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path test.py
62 lines (47 loc) · 2.08 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import torch
from torch.autograd import Variable
from torch.backends import cudnn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import LSUN
from datasets.cityscapes.config import num_classes
from datasets.cityscapes.utils import colorize_mask
from config import ckpt_path, test_results_path
from models import PSPNet
import utils.transforms as expanded_transform
# Module-level side effect: let cuDNN benchmark candidate convolution
# algorithms and reuse the fastest one. This is worthwhile here because main()
# resizes every input to the same fixed 512x1024 shape (FreeScale), so the
# autotuning cost is paid only once.
cudnn.benchmark = True
def main():
    """Run a trained PSPNet over LSUN validation images and save visualizations.

    Loads the checkpoint ``snapshot`` from ``ckpt_path`` and resizes each LSUN
    image to 512x1024. It then predicts a per-pixel class map and writes two
    PNGs per image into ``test_results_path``:

    * ``<n>_img.png`` holds the de-normalized input.
    * ``<n>_out.png`` holds the colorized prediction.

    Here ``n = batch_index * batch_size + index_in_batch``.
    """
    batch_size = 8
    net = PSPNet(pretrained=False, num_classes=num_classes, input_size=(512, 1024)).cuda()
    snapshot = 'epoch_48_validation_loss_5.1326_mean_iu_0.3172_lr_0.00001000.pth'
    net.load_state_dict(torch.load(os.path.join(ckpt_path, snapshot)))
    net.eval()

    # Per-channel normalization statistics (presumably the standard ImageNet
    # values). `restore` inverts `transform`'s normalization so the saved input
    # image is viewable.
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    transform = transforms.Compose([
        expanded_transform.FreeScale((512, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(*mean_std)
    ])
    restore = transforms.Compose([
        expanded_transform.DeNormalize(*mean_std),
        transforms.ToPILImage()
    ])

    lsun_path = '/home/b3-542/LSUN'
    dataset = LSUN(lsun_path, ['tower_val', 'church_outdoor_val', 'bridge_val'], transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=16, shuffle=True)

    # makedirs also creates missing parent directories; os.mkdir would raise.
    if not os.path.isdir(test_results_path):
        os.makedirs(test_results_path)

    # torch.no_grad() replaces the removed `Variable(..., volatile=True)`.
    # Modern PyTorch ignores `volatile` and would build an autograd graph for
    # every batch, wasting GPU memory during pure inference.
    with torch.no_grad():
        for vi, (inputs, _) in enumerate(dataloader):
            inputs = inputs.cuda()
            outputs = net(inputs)
            # Class index per pixel. This assumes outputs is
            # (batch, classes, H, W), giving (batch, H, W); TODO confirm.
            # argmax(dim=1) replaces `.max(1)[1].squeeze_(1)`, which depended
            # on the pre-0.2 keepdim behaviour of max().
            prediction = outputs.argmax(dim=1).cpu().numpy()
            for idx, (image, pred) in enumerate(zip(inputs.cpu(), prediction)):
                stem = vi * batch_size + idx
                restore(image).save(os.path.join(test_results_path, '%d_img.png' % stem))
                colorize_mask(pred).save(os.path.join(test_results_path, '%d_out.png' % stem))
            # Single-argument print() behaves the same on Python 2 and 3.
            print('save the #%d batch, %d images' % (vi + 1, len(prediction)))
if __name__ == '__main__':
    # Run the inference/visualization pass only when executed as a script,
    # not when this module is imported.
    main()