-
Notifications
You must be signed in to change notification settings - Fork 86
/
batchnormfuser.py
executable file
·113 lines (95 loc) · 3.63 KB
/
batchnormfuser.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#!/usr/bin/env python3
###################################################################################################
#
# Copyright (C) 2020-2023 Maxim Integrated Products, Inc. All Rights Reserved.
#
# Maxim Integrated Products, Inc. Default Copyright Notice:
# https://www.maximintegrated.com/en/aboutus/legal/copyrights.html
#
###################################################################################################
"""
Script that is used for fusing/folding batchnorm layers onto conv2d layers.
"""
import argparse
import torch
def bn_fuser(state_dict, *, eps=1e-20):
    """
    Fold BatchNorm parameters into the preceding convolution and return the state dict.

    Every prefix ``<layer>`` that owns a ``<layer>.bn.running_mean`` entry is treated
    as a conv + BN pair. The convolution parameters are looked up under ``<layer>.op``
    or, for older checkpoints, ``<layer>.conv2d``. The conv weight and bias are
    overwritten with the fused values, and the ``<layer>.bn.*`` entries are removed.

    Args:
        state_dict: Mapping of parameter names to tensors. It is modified in place.
        eps: Constant added to the running variance before taking the square root.
            Defaults to the value this function has always used (1e-20).

    Returns:
        The same ``state_dict`` object, after fusion.

    Raises:
        KeyError: If a BN layer has no ``running_var`` entry, or its conv weight is
            found under neither ``.op.weight`` nor ``.conv2d.weight``.
    """
    suffix = '.bn.running_mean'
    # Strip the exact suffix so that nested module paths (e.g. 'module.block.conv1')
    # keep all of their components. sorted() only makes the processing order
    # deterministic; the layers are fused independently of each other.
    convbn_layers = sorted({key[:-len(suffix)] for key in state_dict if key.endswith(suffix)})

    for layer in convbn_layers:
        if layer + '.op.weight' in state_dict:
            conv_key = layer + '.op'
        else:
            conv_key = layer + '.conv2d'  # Compatibility with older checkpoints
        w_key = conv_key + '.weight'
        b_key = conv_key + '.bias'

        bn_key = layer + '.bn'
        r_mean_key = bn_key + '.running_mean'
        r_var_key = bn_key + '.running_var'
        bn_weight_key = bn_key + '.weight'  # BN scale (multiplicative term)
        bn_bias_key = bn_key + '.bias'  # BN shift (additive term)
        batches_key = bn_key + '.num_batches_tracked'

        w = state_dict[w_key]
        num_out = w.shape[0]  # one BN channel per output channel of the conv

        # A conv without a bias and a BN without affine parameters are both
        # allowed; fall back to the identity values (bias 0, scale 1, shift 0).
        b = state_dict[b_key] if b_key in state_dict else w.new_zeros(num_out)
        r_mean = state_dict[r_mean_key]
        # Index directly: a missing running_var must fail loudly rather than
        # reuse the statistics of a previously processed layer.
        r_std = torch.sqrt(state_dict[r_var_key] + eps)
        if bn_weight_key in state_dict:
            bn_weight = state_dict[bn_weight_key]
        else:
            bn_weight = w.new_ones(num_out)
        if bn_bias_key in state_dict:
            bn_bias = state_dict[bn_bias_key]
        else:
            bn_bias = w.new_zeros(num_out)

        # NOTE(review): both BN affine terms are scaled by 1/4 before folding;
        # presumably this compensates for an output scaling applied elsewhere
        # in the toolchain -- confirm before changing.
        bn_weight = 0.25 * bn_weight
        bn_bias = 0.25 * bn_bias

        # bn(conv(x)) = bn_weight * (W x + b - mean) / std + bn_bias
        #             = (W * bn_weight / std) x + ((b - mean) / std * bn_weight + bn_bias)
        # The per-channel factor is reshaped to broadcast over every weight dim but the first.
        channel_scale = (bn_weight / r_std).reshape((num_out,) + (1,) * (w.dim() - 1))
        state_dict[w_key] = w * channel_scale
        state_dict[b_key] = (b - r_mean) / r_std * bn_weight + bn_bias

        for key in (r_mean_key, r_var_key, bn_weight_key, bn_bias_key, batches_key):
            state_dict.pop(key, None)

    return state_dict
def main(args):
    """
    Fuse the BN layers of a checkpoint and write the result to a new file.

    Args:
        args: Parsed command-line namespace providing ``inp_path`` (checkpoint to
            read), ``out_path`` (file to write) and ``out_arch`` (value stored
            under the checkpoint's ``'arch'`` key).
    """
    # Fusion is plain tensor arithmetic, so load onto the CPU. This lets the
    # script run on a machine without the device the checkpoint was saved from.
    # NOTE: torch.load unpickles the file -- only run this on trusted checkpoints.
    model_params = torch.load(args.inp_path, map_location='cpu')
    model_params['state_dict'] = bn_fuser(model_params['state_dict'])
    model_params['arch'] = args.out_arch
    torch.save(model_params, args.out_path)
    print(f'New checkpoint is saved to: {args.out_path}')
if __name__ == '__main__':
    # All arguments are required, so no defaults are given: the previous integer
    # defaults for these string options could never take effect.
    parser = argparse.ArgumentParser(
        description='Fuse/fold batchnorm layers into the preceding conv layers of a checkpoint.')
    parser.add_argument('-i', '--inp_path', type=str, required=True,
                        help='Input checkpoint path')
    parser.add_argument('-o', '--out_path', type=str, required=True,
                        help='Fused output checkpoint path')
    parser.add_argument('-oa', '--out_arch', type=str, required=True,
                        help='Output arch name')
    arguments = parser.parse_args()
    main(arguments)