-
Notifications
You must be signed in to change notification settings - Fork 0
/
example_pyhessian_analysis.py
121 lines (104 loc) · 4.67 KB
/
example_pyhessian_analysis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""run.py:"""
#!/usr/bin/env python
from __future__ import print_function
import time
from collections import OrderedDict
import argparse
from utils import *
from models.resnet import resnet
from pyhessian import *
import torch.nn as nn
from density_plot import get_esd_plot
# Settings
# Command-line interface for the Hessian analysis script. `--resume`, `--ip`
# and `--device_count` are mandatory; everything else has a default.
parser = argparse.ArgumentParser(description='PyTorch Example')

# Integer-valued options, given as (flag, default, help).
for _flag, _default, _help in (
    ('--mini-hessian-batch-size', 200,
     'input batch size for mini-hessian batch (default: 200)'),
    ('--hessian-batch-size', 200,
     'input batch size for hessian (default: 200)'),
    ('--seed', 1, 'random seed (default: 1)'),
):
    parser.add_argument(_flag, type=int, default=_default, help=_help)

# NOTE: these are `store_false` switches, so each attribute is True by default
# and becomes False only when the flag is passed on the command line.
for _flag, _help in (
    ('--batch-norm', 'do we need batch norm or not'),
    ('--residual', 'do we need residual connect or not'),
    ('--cuda', 'do we use gpu or not'),
):
    parser.add_argument(_flag, action='store_false', help=_help)

parser.add_argument('--resume', type=str, required=True, help='get the checkpoint')

# eigen info
# Opt-in analyses; each one is off unless its flag is given.
for _name, _help in (
    ('eigenvalue', "to calculate top eigenvalue of the hessian"),
    ('trace', 'to calculate trace of the hessian'),
    ('density', 'to calculate esd of the hessian'),
):
    parser.add_argument('--' + _name, dest=_name, action='store_true', help=_help)
parser.set_defaults(eigenvalue=False, trace=False, density=False)

# for parallel computing
parser.add_argument('--ip', type=str, required=True,
                    help='ip address of the machine for distributed computing')
parser.add_argument('--device_count', type=int, required=True,
                    help='number of available devices')

args = parser.parse_args()
# Seed the torch RNG (and the CUDA RNG when args.cuda is set) so that runs
# with the same --seed are repeatable.
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Echo every parsed option, one "name value" pair per line.
for arg_name, arg_value in vars(args).items():
    print(arg_name, arg_value)
##############
# Get the hessian data
##############
# Validate the sizes up front with explicit exceptions. `assert` is stripped
# under `python -O`, and zero values would otherwise either divide by zero or
# (for a zero Hessian batch size) make the batch-collection step below run
# over the entire training set.
if args.mini_hessian_batch_size <= 0 or args.hessian_batch_size <= 0:
    raise ValueError("--mini-hessian-batch-size and --hessian-batch-size must be positive")
if args.device_count <= 0:
    raise ValueError("--device_count must be positive")
if args.hessian_batch_size % args.mini_hessian_batch_size != 0:
    raise ValueError("--hessian-batch-size must be a multiple of --mini-hessian-batch-size")

# Number of mini-batches that together make up one Hessian batch.
batch_num = args.hessian_batch_size // args.mini_hessian_batch_size

# get dataset
# The training loader is batched at the mini-Hessian batch size, so each batch
# it yields is one mini-batch of the Hessian batch.
train_loader, test_loader = getData(name='cifar10_without_data_augmentation',
                                    train_bs=args.mini_hessian_batch_size,
                                    test_bs=1)

# getting the dataset batches
# Keep only the first `batch_num` mini-batches. `range` is listed first in the
# zip so that no extra batch is pulled from the loader once the quota is met.
hessian_dataloader = [
    (inputs, labels)
    for _, (inputs, labels) in zip(range(batch_num), train_loader)
]

# dividing dataset into partitions
# Every device must receive the same number of mini-batches.
if len(hessian_dataloader) % args.device_count != 0:
    raise ValueError("Mini-batches must be uniformly divided among GPUs")
size = [len(hessian_dataloader) // args.device_count] * args.device_count

# partitioning data into number of GPUs available
data_partitions = DataPartitioner(hessian_dataloader, size)
# label loss
# Cross-entropy criterion handed to the Hessian routines further down.
criterion = nn.CrossEntropyLoss()

# get model
# ResNet with depth 20 and 10 output classes. The residual / batch-norm
# switches are forwarded straight from the command-line flags.
model = resnet(
    depth=20,
    num_classes=10,
    batch_norm_not=args.batch_norm,
    residual_not=args.residual,
)
###################
# Get model checkpoint, get saving folder
###################
if args.resume == '':
    raise Exception("please choose the trained model")

# loading the state dictionary into the model
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only point --resume at checkpoints from a trusted source.
state_dict = torch.load(args.resume, map_location=torch.device('cpu'))

# A model trained under DataParallel saves its parameters with a 'module.'
# prefix on every key. Strip that prefix only where it is actually present, so
# a checkpoint saved from an unwrapped model keeps its keys intact instead of
# losing the first seven characters of each one.
_DATA_PARALLEL_PREFIX = 'module.'
state_dict_ = OrderedDict(
    (key[len(_DATA_PARALLEL_PREFIX):] if key.startswith(_DATA_PARALLEL_PREFIX) else key,
     value)
    for key, value in state_dict.items()
)
model.load_state_dict(state_dict_)
######################################################
# Begin the computation
######################################################
# Only run the (expensive) Hessian computations when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    # Positional arguments shared by all three Hessian routines.
    hessian_args = (args.device_count, model, data_partitions, criterion, args.ip)

    if args.eigenvalue:
        start = time.time()
        top_eigenvalues = eigenvalue(*hessian_args)
        end = time.time()
        print('\n***Top Eigenvalues: ', top_eigenvalues)
        print("Time to compute top eigenvalue: %f" % (end - start))

    if args.trace:
        start = time.time()
        # Use a distinct name for the result so the `trace` function is not
        # overwritten by its own return value.
        hessian_trace = trace(*hessian_args)
        end = time.time()
        print('\n***Trace: ', hessian_trace)
        print("Time to compute trace: %f" % (end - start))

    if args.density:
        start = time.time()
        density_eigen, density_weight = density(*hessian_args)
        end = time.time()
        # Plotting is deliberately outside the timed region.
        get_esd_plot(density_eigen, density_weight)
        print("Time to compute esd: %f" % (end - start))