Skip to content

Commit

Permalink
take care of automatic audio -> mel if mel is not passed in, and doin…
Browse files Browse the repository at this point in the history
…g conditional training
  • Loading branch information
lucidrains committed Aug 30, 2023
1 parent e74ef26 commit 82f09e6
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 16 deletions.
102 changes: 87 additions & 15 deletions naturalspeech2_pytorch/naturalspeech2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from torch.utils.data import Dataset, DataLoader

import torchaudio
import torchaudio.transforms as T

from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
Expand All @@ -38,7 +39,6 @@
from tqdm.auto import tqdm
import pyworld as pw


# constants

mlist = nn.ModuleList
Expand Down Expand Up @@ -127,6 +127,53 @@ def f0_to_coarse(f0, f0_bin = 256, f0_max = 1100.0, f0_min = 50.0):

# peripheral models

# audio to mel

class AudioToMel(nn.Module):
def __init__(
self,
*,
n_mels = 100,
sampling_rate = 24000,
f_max = 8000,
n_fft = 1024,
win_length = 640,
hop_length = 160,
log = True
):
super().__init__()
self.log = log
self.n_mels = n_mels
self.n_fft = n_fft
self.f_max = f_max
self.win_length = win_length
self.hop_length = hop_length
self.sampling_rate = sampling_rate

def forward(self, audio):
stft_transform = T.Spectrogram(
n_fft = self.n_fft,
win_length = self.win_length,
hop_length = self.hop_length,
window_fn = torch.hann_window
)

spectrogram = stft_transform(audio)

mel_transform = T.MelScale(
n_mels = self.n_mels,
sample_rate = self.sampling_rate,
n_stft = self.n_fft // 2 + 1,
f_max = self.f_max
)

mel = mel_transform(spectrogram)

if self.log:
mel = T.AmplitudeToDB()(mel)

return mel

# phoneme - pitch - speech prompt - duration predictors

class PhonemeEncoder(nn.Module):
Expand Down Expand Up @@ -1048,20 +1095,17 @@ def __init__(
aligner_dim_in: int = 80,
aligner_dim_hidden: int = 512,
aligner_attn_channels: int = 80,
num_phoneme_tokens: int = 150,
pitch_emb_dim: int = 256,
pitch_emb_pp_hidden_dim: int= 512,
pitch_emb_pp_hidden_dim: int= 512,
audio_to_mel_kwargs: dict = dict(),
scale = 1. # this will be set to < 1. for better convergence when training on higher resolution images
):
super().__init__()

self.conditional = model.condition_on_prompt

if self.conditional:
self.phoneme_enc = PhonemeEncoder(tokenizer=tokenizer, num_tokens=150)
self.prompt_enc = SpeechPromptEncoder(dim_codebook=dim_codebook)
self.duration_pitch = DurationPitchPredictor(dim=duration_pitch_dim)
self.aligner = Aligner(dim_in=aligner_dim_in, dim_hidden=aligner_dim_hidden, attn_channels=aligner_attn_channels)
self.pitch_emb = nn.Embedding(pitch_emb_dim, pitch_emb_pp_hidden_dim)
# model and codec

self.model = model
self.codec = codec
Expand All @@ -1075,6 +1119,25 @@ def __init__(
self.target_sample_hz = codec.target_sample_hz
self.seq_len_multiple_of = codec.seq_len_multiple_of

# preparation for conditioning

if self.conditional:
if exists(self.target_sample_hz):
audio_to_mel_kwargs.update(sampling_rate = self.target_sample_hz)

self.audio_to_mel = AudioToMel(
n_mels = aligner_dim_in,
**audio_to_mel_kwargs
)

self.phoneme_enc = PhonemeEncoder(tokenizer=tokenizer, num_tokens=num_phoneme_tokens)
self.prompt_enc = SpeechPromptEncoder(dim_codebook=dim_codebook)
self.duration_pitch = DurationPitchPredictor(dim=duration_pitch_dim)
self.aligner = Aligner(dim_in=aligner_dim_in, dim_hidden=aligner_dim_hidden, attn_channels=aligner_attn_channels)
self.pitch_emb = nn.Embedding(pitch_emb_dim, pitch_emb_pp_hidden_dim)

# rest of ddpm

assert not exists(codec) or model.dim == codec.codebook_dim, f'transformer model dimension {model.dim} must be equal to codec dimension {codec.codebook_dim}'

self.dim = codec.codebook_dim if exists(codec) else model.dim
Expand Down Expand Up @@ -1325,20 +1388,29 @@ def forward(

assert not (is_raw_audio and not exists(self.codec)), 'codec must be passed in if one were to train on raw audio'

if self.conditional:
assert exists(text) and exists(pitch) # eventually make pitch automatically computed if not passed in
text_max_length = text.shape[-1]

if not exists(mel):
assert is_raw_audio

mel = self.audio_to_mel(audio)
mel = mel[..., :text_max_length]

mel_max_length = mel.shape[-1]

if is_raw_audio:
with torch.no_grad():
self.codec.eval()
audio, codes, _ = self.codec(audio, return_encoded = True)

# compute the prompt encoding and cond

prompt_enc = None
cond = None

if self.conditional:
assert exists(mel) and exists(text) and exists(pitch)

mel_max_length = mel.shape[-1]
text_max_length = text.shape[-1]

if not exists(mel_lens):
mel_lens = torch.full((batch,), mel_max_length, device = self.device, dtype = torch.long)

Expand All @@ -1347,7 +1419,7 @@ def forward(

mel_mask = rearrange(create_mask(mel_lens, mel_max_length), 'b n -> b 1 n')
text_mask = rearrange(create_mask(text_lens, text_max_length), 'b n -> b 1 n')

prompt = self.process_prompt(prompt)
prompt_enc = self.prompt_enc(prompt)
phoneme_enc = self.phoneme_enc(text)
Expand Down Expand Up @@ -1378,7 +1450,7 @@ def forward(

# predict and take gradient step

pred = self.model(noised_audio, times, prompt = prompt_enc, cond=cond)
pred = self.model(noised_audio, times, prompt = prompt_enc, cond = cond)

if self.objective == 'eps':
target = noise
Expand Down
2 changes: 1 addition & 1 deletion naturalspeech2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.0.43'
__version__ = '0.0.44'

0 comments on commit 82f09e6

Please sign in to comment.