Skip to content

Commit

Permalink
redo duration pitch predictor
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 12, 2023
1 parent 7d2feee commit 473fbce
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 36 deletions.
3 changes: 2 additions & 1 deletion naturalspeech2_pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
Wavenet,
Model,
Trainer,
PhonemeEncoder
PhonemeEncoder,
DurationPitchPredictor
)

from audiolm_pytorch import (
Expand Down
77 changes: 43 additions & 34 deletions naturalspeech2_pytorch/naturalspeech2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ def __init__(
self,
dim,
*,
dim_context = None,
causal = False,
dim_head = 64,
heads = 8,
Expand All @@ -400,10 +401,11 @@ def __init__(
self.cross_attn_include_queries = cross_attn_include_queries

dim_inner = dim_head * heads
dim_context = default(dim_context, dim)

self.attend = Attend(causal = causal, dropout = dropout, use_flash = use_flash)
self.to_q = nn.Linear(dim, dim_inner, bias = False)
self.to_kv = nn.Linear(dim, dim_inner * 2, bias = False)
self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias = False)
self.to_out = nn.Linear(dim_inner, dim, bias = False)

def forward(self, x, context = None):
Expand Down Expand Up @@ -558,9 +560,13 @@ def forward(self, x):

# the duration and pitch predictors seem to be the same

class DurationOrPitchPredictor(nn.Module):
class DurationPitchPredictor(nn.Module):
def __init__(
self,
*,
dim,
num_phoneme_tokens,
dim_encoded_prompts = None,
depth = 30,
kernel_size = 3,
heads = 8,
Expand All @@ -570,54 +576,57 @@ def __init__(
use_flash_attn = False
):
super().__init__()
dim_encoded_prompts = default(dim_encoded_prompts, dim)

conv_layers = []

for _ in range(depth):
conv_layers.extend([
nn.Conv1d(dim_hidden, dim_hidden, kernel_size, padding = kernel_size // 2),
nn.ReLU(),
nn.Dropout(dropout)
])
self.phoneme_token_emb = nn.Embedding(num_phoneme_tokens, dim)

self.conv_layers = nn.Sequential(*conv_layers)
self.layers = nn.ModuleList([])

self.attn_layers = nn.ModuleList([
Attention(
dim_hidden,
heads = heads,
dim_head = dim_head,
dropout = dropout,
use_flash_attn = use_flash_attn,
cross_attn_include_queries = True
)
for _ in range(attn_depth)
])
for _ in range(depth):
self.layers.append(nn.ModuleList([
nn.Sequential(
Rearrange('b n c -> b c n'),
nn.Conv1d(dim_hidden, dim_hidden, kernel_size, padding = kernel_size // 2),
nn.SiLU(),
nn.Dropout(dropout),
Rearrange('b c n -> b n c'),
),
RMSNorm(dim),
Attention(
dim_hidden,
dim_context = dim_encoded_prompts,
heads = heads,
dim_head = dim_head,
dropout = dropout,
use_flash = use_flash_attn,
cross_attn_include_queries = True
)
]))

self.to_pred = nn.Sequential(
Rearrange('b c n -> b n c'),
nn.Linear(hidden_size, 1),
nn.ReLU(),
Rearrange('b n 1 -> b n')
nn.Linear(dim_hidden, 2),
nn.ReLU()
)

def forward(
self,
x,
encoded_prompts,
labels = None
duration = None,
pitch = None
):
x = self.conv_layers(x)
x = self.phoneme_token_emb(x)

for attn in self.attn_layers:
x = attn(x, encoded_prompts)
for conv, norm, attn in self.layers:
x = conv(x)
x = attn(norm(x), encoded_prompts) + x

pred = self.to_pred(x)
duration_pred, pitch_pred = self.to_pred(x).unbind(dim = -1)

if not exists(labels):
return pred
duration_return = F.l1_loss(duration, duration_pred) if exists(duration) else duration_pred
pitch_return = F.l1_loss(pitch, pitch_pred) if exists(pitch) else pitch_pred

return F.l1_loss(pred, labels)
return duration_return, pitch_return

# use perceiver resampler from flamingo paper - https://arxiv.org/abs/2204.14198
# in lieu of "q-k-v" attention, with the m queries becoming the keys / values on which the ddpm network is conditioned
Expand Down
2 changes: 1 addition & 1 deletion naturalspeech2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.0.19'
__version__ = '0.0.20'

0 comments on commit 473fbce

Please sign in to comment.