diff --git a/naturalspeech2_pytorch/__init__.py b/naturalspeech2_pytorch/__init__.py
index 3e360d8..ce58930 100644
--- a/naturalspeech2_pytorch/__init__.py
+++ b/naturalspeech2_pytorch/__init__.py
@@ -11,7 +11,8 @@
     Wavenet,
     Model,
     Trainer,
-    PhonemeEncoder
+    PhonemeEncoder,
+    DurationPitchPredictor
 )
 
 from audiolm_pytorch import (
diff --git a/naturalspeech2_pytorch/naturalspeech2_pytorch.py b/naturalspeech2_pytorch/naturalspeech2_pytorch.py
index 5b633b9..404ec51 100644
--- a/naturalspeech2_pytorch/naturalspeech2_pytorch.py
+++ b/naturalspeech2_pytorch/naturalspeech2_pytorch.py
@@ -387,6 +387,7 @@ def __init__(
         self,
         dim,
         *,
+        dim_context = None,
         causal = False,
         dim_head = 64,
         heads = 8,
@@ -400,10 +401,11 @@ def __init__(
         self.cross_attn_include_queries = cross_attn_include_queries
 
         dim_inner = dim_head * heads
+        dim_context = default(dim_context, dim)
 
         self.attend = Attend(causal = causal, dropout = dropout, use_flash = use_flash)
         self.to_q = nn.Linear(dim, dim_inner, bias = False)
-        self.to_kv = nn.Linear(dim, dim_inner * 2, bias = False)
+        self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias = False)
         self.to_out = nn.Linear(dim_inner, dim, bias = False)
 
     def forward(self, x, context = None):
@@ -558,9 +560,13 @@ def forward(self, x):
 
 # duration and pitch predictor seems to be the same
 
-class DurationOrPitchPredictor(nn.Module):
+class DurationPitchPredictor(nn.Module):
     def __init__(
         self,
+        *,
+        dim,
+        num_phoneme_tokens,
+        dim_encoded_prompts = None,
         depth = 30,
         kernel_size = 3,
         heads = 8,
@@ -570,54 +576,57 @@ def __init__(
         use_flash_attn = False
     ):
         super().__init__()
+        dim_encoded_prompts = default(dim_encoded_prompts, dim)
 
-        conv_layers = []
-
-        for _ in range(depth):
-            conv_layers.extend([
-                nn.Conv1d(dim_hidden, dim_hidden, kernel_size, padding = kernel_size // 2),
-                nn.ReLU(),
-                nn.Dropout(dropout)
-            ])
+        self.phoneme_token_emb = nn.Embedding(num_phoneme_tokens, dim)
 
-        self.conv_layers = nn.Sequential(*conv_layers)
+        self.layers = nn.ModuleList([])
 
-        self.attn_layers = nn.ModuleList([
-            Attention(
-                dim_hidden,
-                heads = heads,
-                dim_head = dim_head,
-                dropout = dropout,
-                use_flash_attn = use_flash_attn,
-                cross_attn_include_queries = True
-            )
-            for _ in range(attn_depth)
-        ])
+        for _ in range(depth):
+            self.layers.append(nn.ModuleList([
+                nn.Sequential(
+                    Rearrange('b n c -> b c n'),
+                    nn.Conv1d(dim_hidden, dim_hidden, kernel_size, padding = kernel_size // 2),
+                    nn.SiLU(),
+                    nn.Dropout(dropout),
+                    Rearrange('b c n -> b n c'),
+                ),
+                RMSNorm(dim),
+                Attention(
+                    dim_hidden,
+                    dim_context = dim_encoded_prompts,
+                    heads = heads,
+                    dim_head = dim_head,
+                    dropout = dropout,
+                    use_flash = use_flash_attn,
+                    cross_attn_include_queries = True
+                )
+            ]))
 
         self.to_pred = nn.Sequential(
-            Rearrange('b c n -> b n c'),
-            nn.Linear(hidden_size, 1),
-            nn.ReLU(),
-            Rearrange('b n 1 -> b n')
+            nn.Linear(dim_hidden, 2),
+            nn.ReLU()
         )
 
     def forward(
         self,
         x,
         encoded_prompts,
-        labels = None
+        duration = None,
+        pitch = None
     ):
-        x = self.conv_layers(x)
+        x = self.phoneme_token_emb(x)
 
-        for attn in self.attn_layers:
-            x = attn(x, encoded_prompts)
+        for conv, norm, attn in self.layers:
+            x = conv(x)
+            x = attn(norm(x), encoded_prompts) + x
 
-        pred = self.to_pred(x)
+        duration_pred, pitch_pred = self.to_pred(x).unbind(dim = -1)
 
-        if not exists(labels):
-            return pred
+        duration_return = F.l1_loss(duration, duration_pred) if exists(duration) else duration_pred
+        pitch_return = F.l1_loss(pitch, pitch_pred) if exists(pitch) else pitch_pred
 
-        return F.l1_loss(pred, labels)
+        return duration_return, pitch_return
 
 # use perceiver resampler from flamingo paper - https://arxiv.org/abs/2204.14198
 # in lieu of "q-k-v" attention with the m queries becoming key / values on which ddpm network is conditioned on
diff --git a/naturalspeech2_pytorch/version.py b/naturalspeech2_pytorch/version.py
index f12684f..3106bbe 100644
--- a/naturalspeech2_pytorch/version.py
+++ b/naturalspeech2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.0.19'
+__version__ = '0.0.20'
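
Review note: the Attention change above lets keys / values be projected from a context of a different width than the queries. A minimal sketch of the intended usage, assuming Attention is imported from naturalspeech2_pytorch.naturalspeech2_pytorch (it is not re-exported by __init__.py) and with purely illustrative shapes:

    import torch
    from naturalspeech2_pytorch.naturalspeech2_pytorch import Attention

    attn = Attention(dim = 512, dim_context = 256, heads = 8, dim_head = 64)

    x = torch.randn(1, 120, 512)     # queries at the model width
    ctx = torch.randn(1, 32, 256)    # keys / values from a narrower prompt encoding

    out = attn(x, context = ctx)     # (1, 120, 512) - output stays at the query width

When cross_attn_include_queries = True (as DurationPitchPredictor sets it), the queries are presumably concatenated onto the context, so in that configuration the context width still has to match the query width - which is why dim_encoded_prompts defaults to dim.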
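
And a hedged usage sketch for the renamed DurationPitchPredictor, which now embeds phoneme ids itself and predicts duration and pitch in a single pass: with no targets it returns the raw predictions, with targets it returns the two L1 losses. The hyperparameters below are hypothetical, and dim is chosen to match dim_hidden (whose default is not visible in this hunk) so the conv and attention blocks line up:

    import torch
    from naturalspeech2_pytorch import DurationPitchPredictor

    predictor = DurationPitchPredictor(
        dim = 512,                  # assumed equal to dim_hidden
        num_phoneme_tokens = 150    # hypothetical phoneme vocabulary size
    )

    phonemes = torch.randint(0, 150, (1, 120))   # (batch, phoneme ids)
    prompts = torch.randn(1, 32, 512)            # (batch, prompt length, dim_encoded_prompts)

    # inference: no targets given, so raw per-phoneme predictions come back
    duration_pred, pitch_pred = predictor(phonemes, prompts)

    # training: passing targets returns the two scalar L1 losses instead
    duration = torch.rand(1, 120)
    pitch = torch.rand(1, 120)
    duration_loss, pitch_loss = predictor(phonemes, prompts, duration = duration, pitch = pitch)

The final ReLU in to_pred keeps both predictions non-negative, which fits duration and pitch targets.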