diff --git a/README.md b/README.md
index 316d8b7..afb9d8b 100644
--- a/README.md
+++ b/README.md
@@ -28,7 +28,7 @@ $ pip install naturalspeech2-pytorch
 import torch
 from naturalspeech2_pytorch import (
     EncodecWrapper,
-    Transformer,
+    Model,
     NaturalSpeech2
 )
@@ -63,6 +63,56 @@ generated_audio = diffusion.sample(length = 1024) # (1, 327680)
 ```
 
+To accept prompting, instantiate a `SpeechPromptEncoder` and pass it in as `speech_prompt_encoder` on the `Model`
+
+ex.
+
+```python
+import torch
+from naturalspeech2_pytorch import (
+    EncodecWrapper,
+    Model,
+    NaturalSpeech2,
+    SpeechPromptEncoder
+)
+
+# use encodec as an example
+
+codec = EncodecWrapper()
+
+prompt_encoder = SpeechPromptEncoder(
+    dim_codebook = codec.codebook_dim,
+    depth = 2
+)
+
+model = Model(
+    dim = 128,
+    depth = 6,
+    speech_prompt_encoder = prompt_encoder  # pass in the SpeechPromptEncoder
+)
+
+# natural speech diffusion model
+
+diffusion = NaturalSpeech2(
+    model = model,
+    codec = codec,
+    timesteps = 1000
+).cuda()
+
+# mock raw audio data
+
+raw_audio = torch.randn(4, 327680).cuda()
+prompt = torch.randn(4, 32768).cuda() # a random range of the audio is excised as the prompt during training - eventually this will be taken care of auto-magically
+
+loss = diffusion(raw_audio, prompt = prompt) # pass in the prompt
+loss.backward()
+
+# do the above in a loop for a lot of raw audio data...
+# then you can sample from your generative model like so
+
+generated_audio = diffusion.sample(length = 1024, prompt = prompt) # pass in your prompt
+```
+
 Or if you want a `Trainer` class to take care of the training and sampling loop, just simply do
 
 ```python
@@ -78,6 +128,13 @@ trainer = Trainer(
 trainer.train()
 ```
 
+## Todo
+
+- [ ] complete perceiver then cross attention conditioning on ddpm side
+- [ ] add classifier free guidance, even if not in paper
+- [ ] add self-conditioning on ddpm side
+- [ ] complete duration / pitch prediction during training
+
 ## Citations
 
 ```bibtex
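The file diff that follows adds the prompt encoders and wires the prompt into the model's conditioning signal. As a reading aid, here is a minimal, self-contained sketch (dimensions are hypothetical, and `to_prompt_cond` is recreated inline rather than imported from the library) of the pattern `Model` gains below: the encoded prompt is mean-pooled, projected to the width of the time embedding, and concatenated onto it.

```python
import torch
from torch import nn
from einops.layers.torch import Reduce

dim_prompt, dim_time = 512, 512   # hypothetical widths

# mirrors the to_prompt_cond built in Model.__init__ in the diff below
to_prompt_cond = nn.Sequential(
    Reduce('b n d -> b d', 'mean'),   # mean pool the encoded prompt over its time axis
    nn.Linear(dim_prompt, dim_time),
    nn.SiLU()
)

t = torch.randn(2, dim_time)                      # time condition
encoded_prompt = torch.randn(2, 256, dim_prompt)  # stand-in for SpeechPromptEncoder output

p = to_prompt_cond(encoded_prompt)
t = torch.cat((t, p), dim = -1)   # conditioning doubles in width, hence dim_cond_mult * 2
print(t.shape)                    # torch.Size([2, 1024])
```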
diff --git a/naturalspeech2_pytorch/naturalspeech2_pytorch.py b/naturalspeech2_pytorch/naturalspeech2_pytorch.py
index b2de39e..d8e5ed4 100644
--- a/naturalspeech2_pytorch/naturalspeech2_pytorch.py
+++ b/naturalspeech2_pytorch/naturalspeech2_pytorch.py
@@ -14,7 +14,7 @@ import torchaudio
 
 from einops import rearrange, reduce, repeat
-from einops.layers.torch import Rearrange
+from einops.layers.torch import Rearrange, Reduce
 
 from audiolm_pytorch import SoundStream, EncodecWrapper
 from audiolm_pytorch.data import SoundDataset, get_dataloader
 
@@ -68,6 +68,222 @@ def forward(self, x):
         fouriered = torch.cat((x, fouriered), dim = -1)
         return fouriered
 
+# peripheral models
+
+# phoneme - pitch - speech prompt - duration predictors
+
+class PhonemeEncoder(nn.Module):
+    def __init__(
+        self,
+        dim = 512,
+        dim_hidden = 1024,
+        kernel_size = 9,
+        depth = 6,
+        dim_head = 64,
+        heads = 8,
+        conv_dropout = 0.2,
+        attn_dropout = 0.,
+        use_flash = False
+    ):
+        super().__init__()
+
+        same_padding = (kernel_size - 1) // 2
+
+        self.conv = nn.Sequential(
+            Rearrange('b n c -> b c n'),
+            nn.Conv1d(dim, dim_hidden, kernel_size, padding = same_padding),
+            nn.SiLU(),
+            nn.Dropout(conv_dropout),
+            Rearrange('b c n -> b n c'),
+        )
+
+        self.transformer = Transformer(
+            dim = dim_hidden,
+            depth = depth,
+            dim_head = dim_head,
+            heads = heads,
+            dropout = attn_dropout,
+            use_flash = use_flash
+        )
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.transformer(x)
+        return x
+
+class SpeechPromptEncoder(nn.Module):
+
+    @beartype
+    def __init__(
+        self,
+        dim_codebook,
+        dims: Tuple[int, ...] = (256, 2048, 2048, 2048, 2048, 512, 512, 512),
+        *,
+        depth,
+        heads = 8,
+        dim_head = 64,
+        dropout = 0.2,
+        kernel_size = 9,
+        padding = 4,
+        use_flash_attn = True
+
+    ):
+        super().__init__()
+
+        dims = [dim_codebook, *dims]
+
+        self.dim, self.dim_out = dims[0], dims[-1]
+
+        dim_pairs = zip(dims[:-1], dims[1:])
+
+        modules = []
+        for dim_in, dim_out in dim_pairs:
+            modules.extend([
+                nn.Conv1d(dim_in, dim_out, kernel_size, padding = padding),
+                nn.SiLU()
+            ])
+
+        self.conv = nn.Sequential(
+            Rearrange('b n c -> b c n'),
+            *modules,
+            Rearrange('b c n -> b n c')
+        )
+
+        self.transformer = Transformer(
+            dim = dims[-1],
+            depth = depth,
+            heads = heads,
+            dim_head = dim_head,
+            dropout = dropout,
+            use_flash = use_flash_attn
+        )
+
+    def forward(self, x):
+        assert x.shape[-1] == self.dim
+
+        x = self.conv(x)
+        x = self.transformer(x)
+        return x
+
+# duration and pitch predictor seems to be the same
+
+class DurationPitchPredictor(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        num_phoneme_tokens,
+        dim_encoded_prompts = None,
+        depth = 30,
+        kernel_size = 3,
+        heads = 8,
+        dim_head = 64,
+        dim_hidden = 512,
+        dropout = 0.2,
+        use_flash_attn = False
+    ):
+        super().__init__()
+        dim_encoded_prompts = default(dim_encoded_prompts, dim)
+
+        self.phoneme_token_emb = nn.Embedding(num_phoneme_tokens, dim)
+
+        self.layers = nn.ModuleList([])
+
+        for _ in range(depth):
+            self.layers.append(nn.ModuleList([
+                nn.Sequential(
+                    Rearrange('b n c -> b c n'),
+                    nn.Conv1d(dim_hidden, dim_hidden, kernel_size, padding = kernel_size // 2),
+                    nn.SiLU(),
+                    nn.Dropout(dropout),
+                    Rearrange('b c n -> b n c'),
+                ),
+                RMSNorm(dim),
+                Attention(
+                    dim_hidden,
+                    dim_context = dim_encoded_prompts,
+                    heads = heads,
+                    dim_head = dim_head,
+                    dropout = dropout,
+                    use_flash = use_flash_attn,
+                    cross_attn_include_queries = True
+                )
+            ]))
+
+        self.to_pred = nn.Sequential(
+            nn.Linear(dim_hidden, 2),
+            nn.ReLU()
+        )
+
+    def forward(
+        self,
+        x,
+        encoded_prompts,
+        duration = None,
+        pitch = None
+    ):
+        x = self.phoneme_token_emb(x)
+
+        for conv, norm, attn in self.layers:
+            x = conv(x)
+            x = attn(norm(x), encoded_prompts) + x
+
+        duration_pred, pitch_pred = self.to_pred(x).unbind(dim = -1)
+
+        duration_return = F.l1_loss(duration, duration_pred) if exists(duration) else duration_pred
+        pitch_return = F.l1_loss(pitch, pitch_pred) if exists(pitch) else pitch_pred
+
+        return duration_return, pitch_return
+
+# use perceiver resampler from flamingo paper - https://arxiv.org/abs/2204.14198
+# in lieu of "q-k-v" attention with the m queries becoming key / values on which ddpm network is conditioned on
+
+class PerceiverResampler(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        depth,
+        dim_context = None,
+        num_latents = 64, # m in the paper
+        dim_head = 64,
+        heads = 8,
+        ff_mult = 4,
+        use_flash_attn = True
+    ):
+        super().__init__()
+        dim_context = default(dim_context, dim)
+
+        self.latents = nn.Parameter(torch.randn(num_latents, dim))
+        nn.init.normal_(self.latents, std = 0.02)
+
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(nn.ModuleList([
+                Attention(
+                    dim = dim,
+                    dim_head = dim_head,
+                    dim_context = dim_context,
+                    heads = heads,
+                    use_flash = use_flash_attn,
+                    cross_attn_include_queries = True
+                ),
+                FeedForward(dim = dim, mult = ff_mult)
+            ]))
+
+        self.norm = RMSNorm(dim)
+
+    def forward(self, x):
+        batch = x.shape[0]
+
+        latents = repeat(self.latents, 'n d -> b n d', b = batch)
+
+        for attn, ff in self.layers:
+            latents = attn(latents, x) + latents
+            latents = ff(latents) + latents
+
+        return self.norm(latents)
+
 # model, which is wavenet + transformer
 
 class CausalConv1d(nn.Conv1d):
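The `PerceiverResampler` added above is not yet wired into the ddpm network (see the Todo in the README); its job will be to resample a variable-length encoded prompt down to a fixed set of `m` latents that the network can cross attend to. A hedged illustration of that shape contract, using `nn.MultiheadAttention` as a stand-in for the repo's own `Attention` block:

```python
import torch
from torch import nn
from einops import repeat

dim, num_latents = 512, 64   # m = 64 latents, the class default

latents = nn.Parameter(torch.randn(num_latents, dim) * 0.02)
attn = nn.MultiheadAttention(dim, num_heads = 8, batch_first = True)

for frames in (100, 937):                  # any prompt length
    prompt = torch.randn(2, frames, dim)
    queries = repeat(latents, 'n d -> b n d', b = prompt.shape[0])
    out, _ = attn(queries, prompt, prompt)  # latents query; the prompt provides keys / values
    print(out.shape)                        # torch.Size([2, 64, 512]) regardless of prompt length
```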
@@ -92,15 +308,15 @@ def __init__(
         dilation,
         kernel_size = 3,
         skip_conv = False,
-        dim_time_mult = None
+        dim_cond_mult = None
     ):
         super().__init__()
-        self.cond_time = exists(dim_time_mult)
+        self.cond = exists(dim_cond_mult)
 
         self.to_time_cond = None
 
-        if self.cond_time:
-            self.to_time_cond = nn.Linear(dim * dim_time_mult, dim * 2)
+        if self.cond:
+            self.to_time_cond = nn.Linear(dim * dim_cond_mult, dim * 2)
 
         self.conv = CausalConv1d(dim, dim, kernel_size, dilation = dilation)
         self.res_conv = CausalConv1d(dim, dim, 1)
@@ -108,7 +324,7 @@ def forward(self, x, t = None):
 
-        if self.cond_time:
+        if self.cond:
             assert exists(t)
             t = self.to_time_cond(t)
             t = rearrange(t, 'b c -> b c 1')
@@ -118,7 +334,7 @@
         x = self.conv(x)
 
-        if self.cond_time:
+        if self.cond:
             x = x * t_gamma + t_beta
 
         x = x.tanh() * x.sigmoid()
@@ -140,7 +356,7 @@ def __init__(
         layers,
         kernel_size = 3,
         has_skip = False,
-        dim_time_mult = None
+        dim_cond_mult = None
     ):
         super().__init__()
         dilations = 2 ** torch.arange(layers)
@@ -154,7 +370,7 @@
                 kernel_size = kernel_size,
                 dilation = dilation,
                 skip_conv = has_skip,
-                dim_time_mult = dim_time_mult
+                dim_cond_mult = dim_cond_mult
             )
 
             self.blocks.append(block)
@@ -185,7 +401,7 @@ def __init__(
         stacks,
         layers,
         init_conv_kernel = 3,
-        dim_time_mult = None
+        dim_cond_mult = None
     ):
         super().__init__()
         self.init_conv = CausalConv1d(dim, dim, init_conv_kernel)
@@ -197,7 +413,7 @@
             stack = WavenetStack(
                 dim,
                 layers = layers,
-                dim_time_mult = dim_time_mult,
+                dim_cond_mult = dim_cond_mult,
                 has_skip = is_last
             )
 
@@ -234,26 +450,26 @@ def __init__(
         heads = 8,
         ff_mult = 4,
         ff_causal_conv = False,
-        dim_time_mult = None,
+        dim_cond_mult = None,
         use_flash = False
     ):
         super().__init__()
         self.dim = dim
         self.layers = mlist([])
 
-        cond_time = exists(dim_time_mult)
+        cond = exists(dim_cond_mult)
 
-        self.to_time_cond = None
-        self.cond_time = cond_time
+        self.to_cond = None
+        self.cond = cond
 
-        if cond_time:
-            self.to_time_cond = nn.Linear(dim * dim_time_mult, dim * 4)
+        if cond:
+            self.to_cond = nn.Linear(dim * dim_cond_mult, dim * 4)
 
         for _ in range(depth):
             self.layers.append(mlist([
-                RMSNorm(dim, scale = not cond_time),
+                RMSNorm(dim, scale = not cond),
                 Attention(dim = dim, dim_head = dim_head, heads = heads, use_flash = use_flash),
-                RMSNorm(dim, scale = not cond_time),
+                RMSNorm(dim, scale = not cond),
                 FeedForward(dim = dim, mult = ff_mult, causal_conv = ff_causal_conv)
             ]))
@@ -267,9 +483,9 @@ def forward(
         x,
         times = None
     ):
-        if self.cond_time:
+        if self.cond:
             assert exists(times)
-            t = self.to_time_cond(times)
+            t = self.to_cond(times)
             t = rearrange(t, 'b d -> b 1 d')
             t_attn_gamma, t_attn_beta, t_ff_gamma, t_ff_beta = t.chunk(4, dim = -1)
 
@@ -277,7 +493,7 @@
             res = x
             x = attn_norm(x)
 
-            if self.cond_time:
+            if self.cond:
                 x = x * t_attn_gamma + t_attn_beta
 
             x = attn(x) + res
@@ -285,7 +501,7 @@
             res = x
             x = ff_norm(x)
 
-            if self.cond_time:
+            if self.cond:
                 x = x * t_ff_gamma + t_ff_beta
 
             x = ff(x) + res
@@ -293,6 +509,8 @@
         return self.to_pred(x)
 
 class Model(nn.Module):
+
+    @beartype
     def __init__(
         self,
         dim,
@@ -303,15 +521,18 @@ def __init__(
         ff_mult = 4,
         wavenet_layers = 8,
         wavenet_stacks = 4,
-        dim_time_mult = 4,
-        use_flash_attn = True
+        dim_cond_mult = 4,
+        use_flash_attn = True,
+        speech_prompt_encoder: Optional[SpeechPromptEncoder] = None,
+        dim_prompt = None,
+        num_latents_m = 32 # number of latents to be perceiver resampled ('q-k-v' with 'm' queries in the paper)
     ):
         super().__init__()
         self.dim = dim
 
         # time condition
 
-        dim_time = dim * dim_time_mult
+        dim_time = dim * dim_cond_mult
 
         self.to_time_cond = Sequential(
             LearnedSinusoidalPosEmb(dim),
@@ -319,13 +540,35 @@
             nn.SiLU()
         )
 
+        # prompt condition
+
+        self.condition_on_prompt = exists(speech_prompt_encoder)
+
+        self.to_prompt_cond = None
+
+        if self.condition_on_prompt:
+            dim_prompt = default(dim_prompt, speech_prompt_encoder.dim_out)
+
+            self.speech_prompt_encoder = speech_prompt_encoder
+            assert speech_prompt_encoder.dim_out == dim_prompt
+
+            self.to_prompt_cond = Sequential(
+                Reduce('b n d -> b d', 'mean'),
+                nn.Linear(dim_prompt, dim_time),
+                nn.SiLU()
+            )
+
+        # conditioning includes time and optionally prompt
+
+        dim_cond_mult = dim_cond_mult * (2 if self.condition_on_prompt else 1)
+
         # wavenet
 
         self.wavenet = Wavenet(
             dim = dim,
             stacks = wavenet_stacks,
             layers = wavenet_layers,
-            dim_time_mult = dim_time_mult
+            dim_cond_mult = dim_cond_mult
         )
 
         # transformer
@@ -337,17 +580,25 @@
             heads = heads,
             ff_mult = ff_mult,
             ff_causal_conv = True,
-            dim_time_mult = dim_time_mult,
+            dim_cond_mult = dim_cond_mult,
             use_flash = use_flash_attn,
         )
 
     def forward(
         self,
         x,
-        times
+        times,
+        prompt = None
    ):
         t = self.to_time_cond(times)
 
+        if exists(self.to_prompt_cond):
+            assert exists(prompt)
+            encoded_prompt = self.speech_prompt_encoder(prompt)
+
+            p = self.to_prompt_cond(encoded_prompt)
+            t = torch.cat((t, p), dim = -1)
+
         x = rearrange(x, 'b n d -> b d n')
         x = self.wavenet(x, t)
         x = rearrange(x, 'b d n -> b n d')
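The `dim_time_mult` to `dim_cond_mult` renames above reflect that the same scale/shift modulation now carries prompt information in addition to time. A small sketch of that modulation with made-up sizes - the conditioning vector is projected to attention and feedforward gamma/beta pairs and chunked apart, mirroring what `Transformer.forward` does:

```python
import torch
from torch import nn

dim = 128
dim_cond_mult = 8   # 4 for the time condition, doubled when a prompt is concatenated on

to_cond = nn.Linear(dim * dim_cond_mult, dim * 4)

x = torch.randn(2, 1024, dim)                # latent audio sequence
cond = torch.randn(2, dim * dim_cond_mult)   # time (+ prompt) condition

t = to_cond(cond).unsqueeze(1)               # 'b d -> b 1 d'
t_attn_gamma, t_attn_beta, t_ff_gamma, t_ff_beta = t.chunk(4, dim = -1)

x = x * t_attn_gamma + t_attn_beta           # applied before attention; the ff pair before feedforward
print(x.shape)                               # torch.Size([2, 1024, 128])
```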
@@ -470,223 +721,6 @@ def forward(self, x):
 
         return self.norm(x)
 
-# phoneme - pitch - speech prompt - duration predictors
-
-class PhonemeEncoder(nn.Module):
-    def __init__(
-        self,
-        dim = 512,
-        dim_hidden = 1024,
-        kernel_size = 9,
-        depth = 6,
-        dim_head = 64,
-        heads = 8,
-        conv_dropout = 0.2,
-        attn_dropout = 0.,
-        use_flash = False
-    ):
-        super().__init__()
-
-        same_padding = (kernel_size - 1) // 2
-
-        self.conv = nn.Sequential(
-            Rearrange('b n c -> b c n'),
-            nn.Conv1d(dim, dim_hidden, kernel_size, padding = same_padding),
-            nn.SiLU(),
-            nn.Dropout(conv_dropout),
-            Rearrange('b c n -> b n c'),
-        )
-
-        self.transformer = Transformer(
-            dim = dim_hidden,
-            depth = depth,
-            dim_head = dim_head,
-            heads = heads,
-            dropout = attn_dropout,
-            use_flash = use_flash
-        )
-
-    def forward(self, x):
-        x = self.conv(x)
-        x = self.transformer(x)
-        return x
-
-class SpeechPromptEncoder(nn.Module):
-
-    @beartype
-    def __init__(
-        self,
-        codec: Optional[Union[SoundStream, EncodecWrapper]] = None,
-        dims: Tuple[int] = (256, 2048, 2048, 2048, 2048, 512, 512, 512),
-        *,
-        depth,
-        heads = 8,
-        dim_head = 64,
-        dropout = 0.2,
-        kernel_size = 9,
-        padding = 4,
-        use_flash_attn = True
-
-    ):
-        super().__init__()
-
-        self.codec = codec
-
-        if exists(codec):
-            dims = [codec.codebook_dim, *dims]
-
-        dim_pairs = zip(dims[:-1], dims[1:])
-
-        modules = []
-        for dim_in, dim_out in dim_pairs:
-            modules.extend([
-                nn.Conv1d(dim_in, dim_out, kernel_size, padding = padding),
-                nn.SiLU()
-            ])
-
-        self.conv = nn.Sequential(
-            Rearrange('b n c -> b c n'),
-            *modules,
-            Rearrange('b c n -> b n c')
-        )
-
-        self.transformer = Transformer(
-            dim = dims[-1],
-            depth = depth,
-            heads = heads,
-            dim_head = dim_head,
-            dropout = dropout,
-            use_flash = use_flash_attn
-        )
-
-    def forward(self, x):
-        is_raw_audio = x.ndim == 2
-        assert not (is_raw_audio and not exists(self.codec))
-
-        if exists(self.codec) and is_raw_audio:
-            with torch.no_grad():
-                self.codec.eval()
-                x, *_ = self.codec(x, return_encoded = True)
-
-        x = self.conv(x)
-        x = self.transformer(x)
-        return x
-
-# duration and pitch predictor seems to be the same
-
-class DurationPitchPredictor(nn.Module):
-    def __init__(
-        self,
-        *,
-        dim,
-        num_phoneme_tokens,
-        dim_encoded_prompts = None,
-        depth = 30,
-        kernel_size = 3,
-        heads = 8,
-        dim_head = 64,
-        dim_hidden = 512,
-        dropout = 0.2,
-        use_flash_attn = False
-    ):
-        super().__init__()
-        dim_encoded_prompts = default(dim_encoded_prompts, dim)
-
-        self.phoneme_token_emb = nn.Embedding(num_phoneme_tokens, dim)
-
-        self.layers = nn.ModuleList([])
-
-        for _ in range(depth):
-            self.layers.append(nn.ModuleList([
-                nn.Sequential(
-                    Rearrange('b n c -> b c n'),
-                    nn.Conv1d(dim_hidden, dim_hidden, kernel_size, padding = kernel_size // 2),
-                    nn.SiLU(),
-                    nn.Dropout(dropout),
-                    Rearrange('b c n -> b n c'),
-                ),
-                RMSNorm(dim),
-                Attention(
-                    dim_hidden,
-                    dim_context = dim_encoded_prompts,
-                    heads = heads,
-                    dim_head = dim_head,
-                    dropout = dropout,
-                    use_flash = use_flash_attn,
-                    cross_attn_include_queries = True
-                )
-            ]))
-
-        self.to_pred = nn.Sequential(
-            nn.Linear(dim_hidden, 2),
-            nn.ReLU()
-        )
-
-    def forward(
-        self,
-        x,
-        encoded_prompts,
-        duration = None,
-        pitch = None
-    ):
-        x = self.phoneme_token_emb(x)
-
-        for conv, norm, attn in self.layers:
-            x = conv(x)
-            x = attn(norm(x), encoded_prompts) + x
-
-        duration_pred, pitch_pred = self.to_pred(x).unbind(dim = -1)
-
-        duration_return = F.l1_loss(duration, duration_pred) if exists(duration) else duration_pred
-        pitch_return = F.l1_loss(pitch, pitch_pred) if exists(pitch) else pitch_pred
-
-        return duration_return, pitch_return
-
-# use perceiver resampler from flamingo paper - https://arxiv.org/abs/2204.14198
-# in lieu of "q-k-v" attention with the m queries becoming key / values on which ddpm network is conditioned on
-
-class PerceiverResampler(nn.Module):
-    def __init__(
-        self,
-        *,
-        dim,
-        depth,
-        num_latents = 64, # m in the paper
-        dim_head = 64,
-        heads = 8,
-        ff_mult = 4,
-        use_flash_attn = True
-    ):
-        super().__init__()
-        self.latents = nn.Parameter(torch.randn(num_latents, dim))
-        nn.init.normal_(self.latents, std = 0.02)
-
-        self.layers = nn.ModuleList([])
-        for _ in range(depth):
-            self.layers.append(nn.ModuleList([
-                Attention(
-                    dim = dim,
-                    dim_head = dim_head,
-                    heads = heads,
-                    use_flash_attn = use_flash_attn,
-                    cross_attn_include_queries = True
-                ),
-                FeedForward(dim = dim, mult = ff_mult)
-            ]))
-
-        self.norm = RMSNorm(dim)
-
-    def forward(self, x):
-        batch = x.shape[0]
-
-        latents = repeat(self.latents, 'n d -> b n d', b = batch)
-
-        for attn, ff in self.layers:
-            latents = attn(latents, x) + latents
-            latents = ff(latents) + latents
-
-        return self.norm(latents)
-
 # tensor helper functions
 
 def log(t, eps = 1e-20):
@@ -825,7 +859,7 @@ def get_sampling_timesteps(self, batch, *, device):
         return times
 
     @torch.no_grad()
-    def ddpm_sample(self, shape, time_difference = None):
+    def ddpm_sample(self, shape, prompt = None, time_difference = None):
         batch, device = shape[0], self.device
 
         time_difference = default(time_difference, self.time_difference)
@@ -847,7 +881,7 @@
 
             # get predicted x0
 
-            model_output = self.model(audio, noise_cond)
+            model_output = self.model(audio, noise_cond, prompt = prompt)
 
             # get log(snr)
 
@@ -894,7 +928,7 @@
         return audio
 
     @torch.no_grad()
-    def ddim_sample(self, shape, time_difference = None):
+    def ddim_sample(self, shape, prompt = None, time_difference = None):
         batch, device = shape[0], self.device
 
         time_difference = default(time_difference, self.time_difference)
@@ -924,7 +958,7 @@
 
             # predict x0
 
-            model_output = self.model(audio, times)
+            model_output = self.model(audio, times, prompt = prompt)
 
             # calculate x0 and noise
 
@@ -947,15 +981,38 @@
         return audio
 
+    def process_prompt(self, prompt = None):
+        if not exists(prompt):
+            return None
+
+        assert self.model.condition_on_prompt
+
+        is_raw_prompt = prompt.ndim == 2
+        assert not (is_raw_prompt and not exists(self.codec)), 'codec must be passed in if one were to train on raw prompt'
+
+        if is_raw_prompt:
+            with torch.no_grad():
+                self.codec.eval()
+                prompt, _, _ = self.codec(prompt, return_encoded = True)
+
+        return prompt
+
     @torch.no_grad()
     def sample(
         self,
         *,
         length,
+        prompt = None,
         batch_size = 1
     ):
         sample_fn = self.ddpm_sample if not self.use_ddim else self.ddim_sample
-        audio = sample_fn((batch_size, length, self.dim))
+
+        prompt = self.process_prompt(prompt)
+
+        if exists(prompt):
+            batch_size = prompt.shape[0]
+
+        audio = sample_fn((batch_size, length, self.dim), prompt = prompt)
 
         if exists(self.codec):
             audio = self.codec.decode(audio)
@@ -969,6 +1026,7 @@ def forward(
         self,
         audio,
         codes = None,
+        prompt = None,
         *args,
         **kwargs
     ):
@@ -981,6 +1039,8 @@ def forward(
                 self.codec.eval()
                 audio, codes, _ = self.codec(audio, return_encoded = True)
 
+        prompt = self.process_prompt(prompt)
+
         batch, n, d, device = *audio.shape, self.device
 
         assert d == self.dim, f'codec codebook dimension {d} must match model dimensions {self.dim}'
@@ -1001,7 +1061,7 @@
 
         # predict and take gradient step
 
-        pred = self.model(noised_audio, times)
+        pred = self.model(noised_audio, times, prompt = prompt)
 
         if self.objective == 'eps':
             target = noise
diff --git a/naturalspeech2_pytorch/version.py b/naturalspeech2_pytorch/version.py
index 7fb58ee..23a88d0 100644
--- a/naturalspeech2_pytorch/version.py
+++ b/naturalspeech2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.0.21'
+__version__ = '0.0.22'
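Two behavioral notes on the prompting path added in this diff, with illustrative (not prescribed) shapes: `process_prompt` dispatches on tensor rank, and `sample` takes its batch size from the prompt when one is given.

```python
import torch

raw_prompt = torch.randn(4, 32768)          # ndim == 2: raw waveform, encoded through self.codec
encoded_prompt = torch.randn(4, 256, 128)   # ndim == 3: already codec latents, passed through as-is

# during sampling, batch size follows the prompt:
#     if exists(prompt):
#         batch_size = prompt.shape[0]
# so sample(length = 1024, prompt = raw_prompt, batch_size = 1) still returns 4 clips
```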