-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathaverage_fusion.py
44 lines (33 loc) · 1.41 KB
/
average_fusion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import pickle

import numpy as np
import torch

import dataloader
from utils.utils import *
# Late fusion of the spatial (RGB) and motion (optical-flow) streams: the
# per-video score vectors saved by each stream are summed and evaluated
# with top-1 / top-5 accuracy against the test-split labels.


def fuse_video_scores(rgb, opf, test_video):
    """Fuse per-video spatial and motion scores by element-wise summation.

    Summing is equivalent to averaging the two streams for ranking purposes
    (argmax / top-k are unchanged by the constant factor 1/2).

    Args:
        rgb: mapping video name -> spatial-stream score vector, either of
            shape (C,) or (1, C).
        opf: mapping video name -> motion-stream score vector of the same
            shape; must contain every key present in ``rgb``.
        test_video: mapping video name -> 1-based class label (anything
            ``int()`` accepts).

    Returns:
        (preds, labels): ``preds`` is a float64 array of shape (N, C) with
        one row per video, in sorted-name order, holding
        ``rgb[name] + opf[name]``. ``labels`` is an int64 array of shape
        (N,) with the matching 0-based labels.

    Raises:
        ValueError: if ``rgb`` is empty.
        KeyError: if a video in ``rgb`` has no entry in ``opf`` or
            ``test_video``.
    """
    names = sorted(rgb)
    if not names:
        raise ValueError('no video-level predictions to fuse')

    missing = sorted(set(names) - set(opf))
    if missing:
        raise KeyError('motion predictions missing for {} video(s), e.g. {!r}'
                       .format(len(missing), missing[0]))

    # vstack accepts both (C,) and (1, C) rows, and C is taken from the data
    # rather than hard-coded to 101.
    preds = np.vstack([np.asarray(rgb[name]) + np.asarray(opf[name])
                       for name in names]).astype(np.float64)
    # The split labels are 1-based; accuracy() is fed 0-based class indices.
    labels = np.array([int(test_video[name]) - 1 for name in names],
                      dtype=np.int64)
    return preds, labels


if __name__ == '__main__':
    rgb_preds = 'record/spatial/spatial_video_preds.pickle'
    opf_preds = 'record/motion/motion_video_preds.pickle'

    # NOTE(review): pickle.load can execute arbitrary code; only load
    # prediction files produced locally by the training scripts.
    with open(rgb_preds, 'rb') as f:
        rgb = pickle.load(f)
    with open(opf_preds, 'rb') as f:
        opf = pickle.load(f)

    # Only test_video (video name -> label) is needed here; the loaders are
    # built as a side effect of run(). Use a distinct name so the imported
    # `dataloader` module is not shadowed.
    data_loader = dataloader.spatial_dataloader(
        BATCH_SIZE=1,
        num_workers=1,
        path='/home/ubuntu/data/UCF101/spatial_no_sampled/',
        ucf_list='/home/ubuntu/cvlab/pytorch/ucf101_two_stream/github/UCF_data_references/',
        ucf_split='01',
    )
    _, _, test_video = data_loader.run()

    video_level_preds, video_level_labels = fuse_video_scores(rgb, opf, test_video)
    video_level_preds = torch.from_numpy(video_level_preds).float()
    video_level_labels = torch.from_numpy(video_level_labels).long()

    top1, top5 = accuracy(video_level_preds, video_level_labels, topk=(1, 5))
    # Single-argument print behaves the same on Python 2 and Python 3.
    print('top1: %s  top5: %s' % (top1, top5))