From fcf888cb467c4a79414448f2bebb42f73b7c6c90 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=D0=94=D1=8D=D0=BB=D0=B3=D1=8D=D1=80=D0=B4=D0=B0=D0=BB?=
 =?UTF-8?q?=D0=B0=D0=B8=CC=86=20=D0=A1=D2=AF=D1=85=D0=B1=D0=B0=D0=B0=D1=82?=
 =?UTF-8?q?=D0=B0=D1=80?=
Date: Thu, 14 Feb 2019 10:29:49 +0800
Subject: [PATCH] Upsample parameters read from data_config

---
 glow.py  | 4 ++--
 train.py | 6 +++++-
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/glow.py b/glow.py
index fc3a374..9e0ec01 100644
--- a/glow.py
+++ b/glow.py
@@ -177,12 +177,12 @@ def forward(self, forward_input):
 
 class WaveGlow(torch.nn.Module):
     def __init__(self, n_mel_channels, n_flows, n_group, n_early_every,
-                 n_early_size, WN_config):
+                 n_early_size, WN_config, win_length, hop_length):
         super(WaveGlow, self).__init__()
 
         self.upsample = torch.nn.ConvTranspose1d(n_mel_channels,
                                                  n_mel_channels,
-                                                 1024, stride=256)
+                                                 win_length, stride=hop_length)
         assert(n_group % 2 == 0)
         self.n_flows = n_flows
         self.n_group = n_group
diff --git a/train.py b/train.py
index 5d50b9b..74b7fbc 100644
--- a/train.py
+++ b/train.py
@@ -154,7 +154,11 @@ def train(num_gpus, rank, group_name, output_directory, epochs, learning_rate,
     global dist_config
     dist_config = config["dist_config"]
     global waveglow_config
-    waveglow_config = config["waveglow_config"]
+    waveglow_config = {
+        **config["waveglow_config"],
+        'win_length': data_config['win_length'],
+        'hop_length': data_config['hop_length']
+    }
 
     num_gpus = torch.cuda.device_count()
     if num_gpus > 1:
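
For reference, a minimal usage sketch, not part of the patch itself: it assumes the repository's config.json keeps win_length and hop_length under its data_config section, as in the stock WaveGlow configuration, and shows how the merged waveglow_config reaches the patched constructor.

    import json
    from glow import WaveGlow  # the patched glow.py above

    # Load the JSON config and merge the STFT parameters from data_config
    # into waveglow_config, mirroring the change made to train.py.
    with open("config.json") as f:
        config = json.load(f)

    data_config = config["data_config"]
    waveglow_config = {
        **config["waveglow_config"],
        "win_length": data_config["win_length"],
        "hop_length": data_config["hop_length"],
    }

    model = WaveGlow(**waveglow_config)
    # The upsampling ConvTranspose1d now uses kernel_size == win_length and
    # stride == hop_length instead of the previously hard-coded 1024 / 256.
    print(model.upsample)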