-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathconvert_state_dict.py
34 lines (29 loc) · 1.14 KB
/
convert_state_dict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
"""
Script used to change the key names of a state dict so that it can be properly loaded in the evaluation code.
Kind of a hack, but it works...
"""
import os
import sys

import torch
# Prefix placed in front of every converted key.
KEY_PREFIX = "sent_encoder._text_field_embedder.model."

# Top-level entries whose first name component is kept; for every other key
# the first dotted component is dropped before the prefix is added.
_KEPT_TOP_LEVEL = ("facet_lin_emb", "facet_pos_emb")

# Trailing parameter names that are renamed during conversion.
_PARAM_RENAMES = {"gamma": "weight", "beta": "bias"}

USAGE = "Run script with 1 argument as follows: 'python3 convert_state_dict.py <filepath>'"


def convert_key(name: str) -> str:
    """Return the converted form of one state-dict key.

    The last dotted component is renamed from ``gamma``/``beta`` to
    ``weight``/``bias``. The first dotted component is dropped unless it is
    one of ``facet_lin_emb``/``facet_pos_emb``, and ``KEY_PREFIX`` is put in
    front of what remains.
    """
    parts = name.split(".")
    # Rename on the split parts *before* building the final key, so the
    # renamed key goes through the same prefixing as every other key.
    parts[-1] = _PARAM_RENAMES.get(parts[-1], parts[-1])
    start_idx = 0 if parts[0] in _KEPT_TOP_LEVEL else 1
    return KEY_PREFIX + ".".join(parts[start_idx:])


def convert_state_dict(model_state: dict) -> dict:
    """Return a new dict with every key passed through ``convert_key``.

    Values are carried over unchanged and the input mapping is not modified.
    """
    return {convert_key(name): value for name, value in model_state.items()}


def converted_path(state_path: str) -> str:
    """Return ``state_path`` with ``_converted`` inserted before its extension.

    ``model.pt`` becomes ``model_converted.pt``. A path without an extension
    simply gets ``_converted`` appended, and dots in directory names
    (e.g. ``./ckpt/model``) are not mistaken for an extension.
    """
    root, ext = os.path.splitext(state_path)
    return root + "_converted" + ext


def main(argv: list) -> int:
    """Convert the checkpoint named by ``argv[1]`` and return an exit status.

    Prints the usage line and returns 0 when no path is given or when help
    is requested. Otherwise loads the file, converts the mapping stored
    under its ``'sd'`` entry, and saves the result next to the input file.
    """
    if len(argv) < 2 or argv[1] in ("-h", "--help"):
        print(USAGE)
        return 0

    state_path = argv[1]
    # NOTE: torch.load unpickles the file; only run this on trusted checkpoints.
    model_state = torch.load(state_path, map_location='cpu')['sd']
    torch.save(convert_state_dict(model_state), converted_path(state_path))
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv))