-
Notifications
You must be signed in to change notification settings - Fork 7
/
identity.py
31 lines (26 loc) · 888 Bytes
/
identity.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import sys
import torch
def _load_checkpoint(path):
    """Load a full pickled checkpoint from *path*, mapping tensors to the CPU.

    ``map_location='cpu'`` lets the script run on hosts without the device
    the checkpoint was saved from.  ``weights_only=False`` is requested
    because the checkpoint carries a non-tensor ``args`` object, which the
    restricted loader (the default since PyTorch 2.6) refuses to unpickle.
    """
    # SECURITY: a full torch.load unpickles arbitrary Python objects and can
    # execute code embedded in the file.  Only run this on trusted checkpoints.
    try:
        return torch.load(path, map_location='cpu', weights_only=False)
    except TypeError:
        # PyTorch releases that predate the ``weights_only`` keyword reject
        # it; they always perform a full unpickle, so retry without it.
        return torch.load(path, map_location='cpu')


def _init_like(key, template):
    """Return a freshly initialised tensor shaped like *template* for *key*.

    * parameters of ``fc2`` / ``out_proj`` modules and every ``*bias`` entry
      are set to zero;
    * everything else is drawn from a normal distribution with
      ``std = template.size(0) ** -0.5``.
    """
    parts = key.split('.')
    if parts[-2] in ('fc2', 'out_proj') or parts[-1].endswith('bias'):
        return torch.zeros_like(template)
    # NOTE(review): the original comment called this "Kaiming normal", but
    # the scale is 1/sqrt(size of dim 0) with no gain factor -- confirm that
    # dim 0 is the intended fan dimension for these weights.
    std = template.size(0) ** -0.5
    return torch.randn_like(template) * std


def main():
    """Double the encoder depth of a checkpoint by appending new layers.

    Command line: ``identity.py <input_checkpoint> <output_checkpoint>``.

    For every ``encoder.layers.<i>.*`` entry in ``ckpt['model']`` a new entry
    ``encoder.layers.<i + N>.*`` is added, where ``N`` is
    ``ckpt['args'].encoder_layers``.  The new tensors are initialised by
    ``_init_like`` (zeroed output projections and biases, presumably so the
    added layers start out as identity mappings -- verify against the model
    definition).  Existing entries are left untouched,
    ``ckpt['args'].encoder_layers`` is doubled, and the result is written to
    the output path.

    Raises:
        SystemExit: if fewer than two path arguments were supplied.
    """
    if len(sys.argv) < 3:
        raise SystemExit(
            'usage: identity.py <input_checkpoint> <output_checkpoint>')

    ckpt = _load_checkpoint(sys.argv[1])
    state = ckpt['model']
    num_layers = ckpt['args'].encoder_layers

    # Collect the additions first: the dict must not grow while it is being
    # iterated.  Only shape/dtype/device of the source tensor are needed, so
    # the tensor itself is referenced rather than copied.
    new_entries = []
    for key, tensor in state.items():
        parts = key.split('.')
        if len(parts) > 2 and parts[:2] == ['encoder', 'layers']:
            parts[2] = str(int(parts[2]) + num_layers)
            new_entries.append(('.'.join(parts), tensor))

    for key, template in new_entries:
        state[key] = _init_like(key, template)

    ckpt['args'].encoder_layers = num_layers * 2
    torch.save(ckpt, sys.argv[2])
# Script entry point: perform the conversion only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()