-
Notifications
You must be signed in to change notification settings - Fork 0
/
export.py
80 lines (71 loc) · 2.06 KB
/
export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""
Export a PyTorch state dict into a raw binary file.
This does not reduce the model size; it only makes
the model readable from numpy.
"""
from typing import Dict
import torch
import argparse
import os
from hp import hp
import numpy as np
def main(args):
    """Export the state dict at ``args.model`` to the raw binary file ``args.output``.

    Output layout, written in this order:

    1. Eight native-endian int32 values read from ``hp``: n_fft, hop_size,
       win_size, n_mels, hid_dim, phone_dim, sr, followed by the byte length
       of the encoded phone set.
    2. The phone set encoded as ASCII, entries separated by NUL characters
       (no trailing separator).
    3. Every tensor of the state dict, in iteration order, flattened and cast
       to native-endian float32. Names and shapes are NOT stored, so the
       reader must know the tensor order and shapes in advance.

    If ``args.output`` already exists, the user is asked to confirm the
    overwrite; anything other than "y"/"Y" (or EOF on stdin) aborts without
    touching the file.

    Args:
        args: Parsed CLI namespace with ``model`` (path to the ``.pth`` state
            dict file) and ``output`` (path of the ``.bin`` file to write).

    Raises:
        FileNotFoundError: ``args.model`` is not an existing file.
        ValueError: a phone-set entry contains the NUL separator.
        UnicodeEncodeError: the phone set is not pure ASCII.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not os.path.isfile(args.model):
        raise FileNotFoundError(f"model file not found: {args.model}")

    if os.path.isfile(args.output):
        print(f"{args.output} exists, overwrite? [y/N]")
        try:
            answer = input()
        except EOFError:  # non-interactive stdin: take the default (N)
            answer = ""
        if answer.strip().lower() != "y":
            return

    # --- Build everything that can fail BEFORE opening the output, so a bad
    # --- phone set or unreadable checkpoint never truncates an existing file.

    phones = list(hp["phone_set"])
    # NUL is the on-disk separator; a phone containing it would be split
    # incorrectly by the reader, so reject it instead of silently corrupting.
    if any("\0" in phone for phone in phones):
        raise ValueError("phone set entries must not contain the NUL separator")
    phone_set_bytes = "\0".join(phones).encode("ascii")  # FIXME: only ascii is supported

    meta_data = np.array(
        [
            hp["n_fft"],
            hp["hop_size"],
            hp["win_size"],
            hp["n_mels"],
            hp["hid_dim"],
            hp["phone_dim"],
            hp["sr"],
            len(phone_set_bytes),  # phone_set_bytes_len in aligner
        ],
        dtype=np.int32,
    )

    # NOTE(review): torch.load unpickles arbitrary objects; only export
    # checkpoints from trusted sources.
    state_dict: Dict[str, torch.Tensor] = torch.load(args.model, map_location="cpu")

    with open(args.output, "wb") as f:
        f.write(meta_data.tobytes())
        f.write(phone_set_bytes)
        for k, v in state_dict.items():
            print(k, "-", v.shape)
            flat = v.detach().cpu().to(torch.float32).contiguous().reshape(-1).numpy()
            f.write(memoryview(flat))  # zero-copy view of the float32 buffer
if __name__ == "__main__":
    # Command-line entry point: collect the input checkpoint path and the
    # output file name, then hand the parsed namespace to `main`.
    cli = argparse.ArgumentParser(
        description="""
Exports PyTorch's state dict into binary file,
doesn't help in reducing model size
but makes the model readable for numpy.
"""
    )
    cli.add_argument(
        "-m", "--model",
        type=str, required=True,
        help="PyTorch's `.pth` state dict file",
    )
    cli.add_argument(
        "-o", "--output",
        type=str, required=False, default="model.bin",
        help="Name of the output `.bin` file",
    )
    main(cli.parse_args())