-
Notifications
You must be signed in to change notification settings - Fork 2
/
util.py
105 lines (82 loc) · 2.98 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
""" General utility functions """
import torch
import torch.distributed as dist
from enum import Enum
class Summary(Enum):
    """Selects which statistic ``AverageMeter.summary()`` reports."""
    NONE = 0     # report nothing (empty string)
    AVERAGE = 1  # report ``avg``
    SUM = 2      # report ``sum``
    COUNT = 3    # report ``count``


class AverageMeter(object):
    """Computes and stores the average and current value.

    Attributes (all set to 0 by ``reset``):
        val: the most recent value passed to ``update``.
        sum: running total of ``val * n`` over all updates.
        count: running total of the weights ``n``.
        avg: ``sum / count`` (0 while ``count`` is 0).

    Args:
        name: Label printed before the numbers by ``__str__`` and ``summary``.
        fmt: Format spec, including the leading colon, applied to ``val`` and
            ``avg`` in ``__str__`` (default ``':f'``).
        summary_type: ``Summary`` member choosing what ``summary`` reports.
    """

    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        """Set ``val``, ``avg``, ``sum`` and ``count`` back to 0."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the current value and accumulate it with weight ``n``."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard: with no accumulated weight (e.g. ``n=0`` on a fresh meter) a
        # plain division would raise ZeroDivisionError; report 0 like reset().
        self.avg = self.sum / self.count if self.count else 0

    def all_reduce(self):
        """Sum ``sum`` and ``count`` across all processes and recompute ``avg``.

        Calls ``dist.all_reduce`` on the default process group, so that group
        must already be initialised. Afterwards ``sum`` and ``count`` are
        Python floats, because the reduction runs on a float32 tensor.
        """
        # NOTE(review): the device is chosen by local availability, not by the
        # process-group backend; confirm it matches the backend in use.
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")
        # Pack both totals into one tensor so a single collective suffices.
        total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
        dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
        self.sum, self.count = total.tolist()
        # Same zero-count guard as ``update``: every rank may have seen no samples.
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        """Return ``'<name> <val> (<avg>)'`` with both numbers formatted by ``fmt``."""
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

    def summary(self):
        """Return a one-line summary of the statistic selected by ``summary_type``.

        Raises:
            ValueError: if ``summary_type`` is not a ``Summary`` member.
        """
        if self.summary_type is Summary.NONE:
            fmtstr = ''
        elif self.summary_type is Summary.AVERAGE:
            fmtstr = '{name} {avg:.3f}'
        elif self.summary_type is Summary.SUM:
            fmtstr = '{name} {sum:.3f}'
        elif self.summary_type is Summary.COUNT:
            fmtstr = '{name} {count:.3f}'
        else:
            raise ValueError('invalid summary type %r' % self.summary_type)
        return fmtstr.format(**self.__dict__)
class ProgressMeter(object):
    """Print per-batch progress lines and an end-of-run summary for a set of meters.

    Args:
        num_batches: Total number of batches; fixes the width of the batch counter.
        meters: Objects rendered with ``str()`` by ``display`` and with
            ``.summary()`` by ``display_summary``.
        prefix: Text printed in front of the batch counter.
    """

    def __init__(self, num_batches, meters, prefix=""):
        # The "[ b/N]" template depends only on num_batches, so build it once.
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print ``prefix[batch/N]`` followed by every meter's ``str()``, tab-separated."""
        head = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([head, *map(str, self.meters)]))

    def display_summary(self):
        """Print ``' *'`` followed by every meter's ``summary()``, space-separated."""
        parts = [" *"]
        parts.extend(m.summary() for m in self.meters)
        print(' '.join(parts))

    def _get_batch_fmtstr(self, num_batches):
        """Return a template like ``'[{:3d}/100]'`` whose field width fits ``num_batches``."""
        width = len(str(num_batches // 1))
        field = '{:%dd}' % width
        return '[%s/%s]' % (field, field.format(num_batches))
def accuracy(output, target, topk=(1,)):
    """Return the top-1 accuracy of a batch as a percentage.

    Args:
        output: Score tensor. The predicted class of each sample is the index
            of the maximum along dimension 1 (presumably shaped
            ``(batch, num_classes)`` — confirm against callers).
        target: Ground-truth class indices; ``target.size(0)`` is used as the
            batch size.
        topk: Accepted for signature compatibility but ignored; only top-1
            accuracy is computed.

    Returns:
        A Python float, ``100 * (#correct) / batch_size``.

    Raises:
        ZeroDivisionError: if ``target`` is empty along dimension 0.
    """
    with torch.no_grad():
        n_samples = target.size(0)
        predicted = output.max(1)[1]
        n_hits = predicted.eq(target).sum().item()
        return n_hits * 100.0 / n_samples
def accuracy_ind(output, target, topk=(1,)):
    """Return a per-sample boolean mask of top-1 correctness.

    Args:
        output: Score tensor. The predicted class of each sample is the index
            of the maximum along dimension 1 (presumably shaped
            ``(batch, num_classes)`` — confirm against callers).
        target: Ground-truth class indices, compared element-wise against the
            predictions with ``Tensor.eq``.
        topk: Accepted for signature compatibility with ``accuracy`` but
            ignored; only top-1 correctness is computed.

    Returns:
        A bool tensor that is ``True`` where the top-1 prediction equals ``target``.
    """
    with torch.no_grad():
        _, pred = output.max(1)
        return pred.eq(target)