-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference.py
127 lines (106 loc) · 4.52 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# ignore tensorflow deprecation warnings
# see: https://github.com/tensorflow/tensorflow/issues/27045#issuecomment-480691244
import tensorflow as tf
if type(tf.contrib) != type(tf): tf.contrib._warning = None
import matplotlib
matplotlib.use("Agg")
import matplotlib.pylab as plt
import os
import sys
sys.path.append('waveglow/')
import numpy as np
import time
import torch
import argparse
from hparams import create_hparams
from model import Tacotron2
from layers import TacotronSTFT, STFT
from audio_processing import griffin_lim
from train import load_model
from text import text_to_sequence
from scipy.io.wavfile import write
from waveglow.denoiser import Denoiser
from waveglow.mel2samp import files_to_list, MAX_WAV_VALUE
import textwrap
def make_space_above(axes, topmargin=1):  # see https://stackoverflow.com/a/55768955
    """Add ``topmargin`` inches of headroom above the subplots for a title.

    The figure is looked up from the first entry of ``axes.flatten()``. Its
    height and its top/bottom subplot fractions are recomputed together so
    that, per the linked answer, the axes themselves keep their size.

    Args:
        axes: Object exposing ``flatten()`` whose first element has a
            ``figure`` attribute (e.g. the axes array from ``plt.subplots``).
        topmargin: Space to reserve above the axes, in inches.
    """
    figure = axes.flatten()[0].figure
    params = figure.subplotpars
    _, height = figure.get_size_inches()

    # Remove the current top margin, then add the requested one.
    new_height = height - (1 - params.top) * height + topmargin

    # Re-express both margins as fractions of the new figure height.
    bottom_frac = params.bottom * height / new_height
    top_frac = 1 - topmargin / new_height

    figure.subplots_adjust(bottom=bottom_frac, top=top_frac)
    figure.set_figheight(new_height)
def plot_data(data, transcript, image_path, figsize=(11, 4)):
    """Draw each array in ``data`` as an image panel and save the figure.

    Panels are laid out in a single row. The first two are labelled
    ``'output'`` and ``'alignment'``; any further panels get a generic
    ``'panel <i>'`` label. The transcript, wrapped at 100 characters, is
    used as the figure title.

    Args:
        data: Sequence of 2-D arrays, one per panel.
        transcript: Text shown as the figure title.
        image_path: Destination path handed to ``Figure.savefig``.
        figsize: Initial figure size in inches, as ``(width, height)``.
    """
    print("plot results...")
    # squeeze=False keeps `axes` a 2-D array even for a single panel, so both
    # the indexing below and make_space_above() work for any panel count.
    fig, axes = plt.subplots(1, len(data), figsize=figsize, squeeze=False)
    fig_names = ['output', 'alignment']
    for i, (ax, image) in enumerate(zip(axes[0], data)):
        # 'lower' is the supported spelling; current matplotlib accepts only
        # 'upper'/'lower' and rejects the legacy 'bottom'.
        ax.imshow(image, aspect='auto', origin='lower',
                  interpolation='none')
        label = fig_names[i] if i < len(fig_names) else 'panel %d' % i
        ax.set_xlabel(label)

    # see https://stackoverflow.com/a/55768955
    fig.suptitle("\n".join(textwrap.wrap(transcript, 100)))
    make_space_above(axes, topmargin=1)
    fig.savefig(image_path)
    print("All plots saved!: %s" % image_path)
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)
def synthesize(hparams, model, waveglow, outdir, transcript, filename):
    """Synthesize ``transcript`` and write a plot and a wav file to ``outdir``.

    Produces ``<outdir>/<filename>.png`` (post-net mel output and alignment,
    via ``plot_data``) and ``<outdir>/<filename>.wav`` (16-bit PCM).

    Args:
        hparams: Hyper-parameter object; only ``sampling_rate`` is read here.
        model: Object whose ``inference(sequence)`` returns a 4-tuple of
            ``(mel_outputs, mel_outputs_postnet, <ignored>, alignments)``.
        waveglow: Vocoder whose ``infer(mel, sigma=...)`` returns the audio
            tensor (presumably in roughly [-1, 1] -- confirm against the
            vocoder).
        outdir: Existing directory to write the outputs into.
        transcript: Text to synthesize.
        filename: Base name (without extension) for both output files.
    """
    output_mel_path = os.path.join(outdir, "{}.png".format(filename))
    audio_path = os.path.join(outdir, "{}.wav".format(filename))

    # Add a leading batch dimension of 1 to the encoded text.
    sequence = np.array(text_to_sequence(transcript, ['english_cleaners']))[None, :]
    # No torch.autograd.Variable wrapper: it is a deprecated no-op on the
    # torch versions that provide torch.no_grad().
    sequence = torch.from_numpy(sequence).cuda().long()

    with torch.no_grad():
        mel_outputs, mel_outputs_postnet, _, alignments = model.inference(sequence)
        plot_data((mel_outputs_postnet.float().data.cpu().numpy()[0],
                   alignments.float().data.cpu().numpy()[0].T),
                  transcript,
                  output_mel_path)

        print("infer audio...")
        audio = waveglow.infer(mel_outputs_postnet, sigma=0.666)
        audio = (audio * MAX_WAV_VALUE).squeeze().cpu().numpy()

    # Clip before the cast: a float sample outside the int16 range would
    # otherwise wrap around and produce loud clicks in the written file.
    int16_range = np.iinfo(np.int16)
    audio = np.clip(audio, int16_range.min, int16_range.max).astype('int16')

    write(audio_path, hparams.sampling_rate, audio)
    print("\n")
def load_models(hparams, checkpoint_path, waveglow_path):
    """Load the synthesis model and the WaveGlow vocoder onto the GPU.

    Args:
        hparams: Hyper-parameters forwarded to ``load_model``.
        checkpoint_path: File loaded with ``torch.load``; its ``'state_dict'``
            entry is applied to the freshly built model.
        waveglow_path: File loaded with ``torch.load``; its ``'model'`` entry
            is used as the vocoder object.

    Returns:
        Tuple ``(model, waveglow)``, both moved to CUDA and set to eval mode.

    NOTE(review): ``torch.load`` unpickles its input, so both paths must
    point to trusted files.
    """
    print("load models...")

    tacotron = load_model(hparams)
    tacotron_state = torch.load(checkpoint_path)['state_dict']
    tacotron.load_state_dict(tacotron_state)
    tacotron.cuda().eval()

    vocoder = torch.load(waveglow_path)['model']
    vocoder.cuda().eval()
    # Cast every entry of `convinv` to float32 (presumably in case the
    # checkpoint stored them in another precision -- confirm).
    for inv_conv in vocoder.convinv:
        inv_conv.float()

    print("loaded!")
    return tacotron, vocoder
if __name__ == "__main__":
    # Command-line entry point: synthesize one utterance and write the
    # resulting plot and wav under results/single_inference/.
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--checkpoint_path', type=str,
                        required=True, help='checkpoint path')
    parser.add_argument('-t', '--text', type=str,
                        required=True, help='text to synthesize')
    # NOTE: the default below is a machine-specific absolute path; pass -w
    # explicitly on any other machine.
    parser.add_argument('-w', '--waveglow_path', type=str,
                        required=False, help='waveglow path',
                        default='/home/keon/contextron/pretrained_models/waveglow_256channels_universal_v5.pt')
    args = parser.parse_args()

    transcript = args.text
    checkpoint_path = args.checkpoint_path
    waveglow_path = args.waveglow_path

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not os.path.isfile(checkpoint_path):
        parser.error("No such checkpoint")
    if not os.path.isfile(waveglow_path):
        parser.error("No such waveglow")

    hparams = create_hparams()
    hparams.sampling_rate = 22050

    outdir = os.path.join("results", "single_inference")
    # Time-derived hex prefix so repeated runs get distinct file names.
    hash_ = '{0:010x}'.format(int(time.time() * 256))[:10]
    # basename() copes with the platform's path separators, unlike split('/').
    filename = hash_ + '_' + os.path.basename(checkpoint_path)
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(outdir, exist_ok=True)

    model, waveglow = load_models(hparams, checkpoint_path, waveglow_path)

    print("------------- inference -------------")
    print("input text: \n%s" % transcript)
    synthesize(hparams, model, waveglow, outdir, transcript, filename)