complete classifier free guidance
lucidrains committed May 12, 2023
1 parent ab9cba4 commit 5974845
Showing 3 changed files with 73 additions and 15 deletions.
README.md (7 changes: 4 additions & 3 deletions)
@@ -88,7 +88,8 @@ prompt_encoder = SpeechPromptEncoder(
 model = Model(
     dim = 128,
     depth = 6,
-    speech_prompt_encoder = prompt_encoder # pass in the SpeechPromptEncoder
+    speech_prompt_encoder = prompt_encoder, # pass in the SpeechPromptEncoder
+    cond_drop_prob = 0.25 # dropout prompt conditioning with this probability, for classifier free guidance
 )
 
 # natural speech diffusion model
@@ -110,7 +111,7 @@ loss.backward()
 # do the above in a loop for a lot of raw audio data...
 # then you can sample from your generative model as so
 
-generated_audio = diffusion.sample(length = 1024, prompt = prompt) # pass in your prompt
+generated_audio = diffusion.sample(length = 1024, prompt = prompt, cond_scale = 3.) # pass in your prompt - classifier free guidance scale of 3 (1 would be no classifier free guidance)
 ```
 
 Or if you want a `Trainer` class to take care of the training and sampling loop, just simply do
@@ -131,8 +132,8 @@ trainer.train()
 ## Todo
 
 - [x] complete perceiver then cross attention conditioning on ddpm side
+- [x] add classifier free guidance, even if not in paper
 
-- [ ] add classifier free guidance, even if not in paper
 - [ ] add self-conditioning on ddpm side
 - [ ] complete duration / pitch prediction during training
 - [ ] take care of automatic slicing of audio for prompt, being aware of minimal audio segment as allowed by the codec model
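Background for the README change above: classifier free guidance evaluates the model twice at sampling time, once with the prompt conditioning and once with it dropped, then linearly extrapolates between the two predictions. A minimal standalone sketch of that combination (names hypothetical, not this repository's API):

```python
import torch

def guided_pred(cond, uncond, cond_scale):
    # cond_scale = 1. returns the conditional prediction unchanged;
    # larger values push further away from the unconditional one
    return uncond + (cond - uncond) * cond_scale

cond, uncond = torch.randn(2, 4), torch.randn(2, 4)
assert torch.allclose(guided_pred(cond, uncond, 1.), cond)  # scale 1: no guidance
```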
naturalspeech2_pytorch/naturalspeech2_pytorch.py (79 changes: 68 additions & 11 deletions)
@@ -52,6 +52,16 @@ def identity(t, *args, **kwargs):
 def has_int_squareroot(num):
     return (math.sqrt(num) ** 2) == num
 
+# tensor helpers
+
+def prob_mask_like(shape, prob, device):
+    if prob == 1:
+        return torch.ones(shape, device = device, dtype = torch.bool)
+    elif prob == 0:
+        return torch.zeros(shape, device = device, dtype = torch.bool)
+    else:
+        return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
+
 # sinusoidal positional embeds
 
 class LearnedSinusoidalPosEmb(nn.Module):
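The new `prob_mask_like` helper draws an independent boolean per batch element, `True` with probability `prob`, with exact shortcuts for probabilities 0 and 1; `True` rows are the ones whose prompt conditioning gets dropped. A quick sanity check, assuming the helper is importable from this commit's module layout:

```python
import torch
from naturalspeech2_pytorch.naturalspeech2_pytorch import prob_mask_like  # import path assumed

torch.manual_seed(0)
mask = prob_mask_like((100_000,), 0.25, device = torch.device('cpu'))
print(mask.float().mean())  # ~0.25 - the fraction of batch rows that get the null prompt
```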
@@ -535,7 +545,8 @@ def __init__(
         speech_prompt_encoder: Optional[SpeechPromptEncoder] = None,
         dim_prompt = None,
         num_latents_m = 32, # number of latents to be perceiver resampled ('q-k-v' with 'm' queries in the paper)
-        resampler_depth = 2
+        resampler_depth = 2,
+        cond_drop_prob = 0.
     ):
         super().__init__()
         self.dim = dim
@@ -554,6 +565,7 @@ def __init__(
 
         condition_on_prompt = exists(speech_prompt_encoder)
 
+        self.cond_drop_prob = cond_drop_prob # for classifier free guidance
         self.condition_on_prompt = condition_on_prompt
         self.to_prompt_cond = None
 
@@ -563,6 +575,12 @@ def __init__(
             self.speech_prompt_encoder = speech_prompt_encoder
             assert speech_prompt_encoder.dim_out == dim_prompt
 
+            self.null_prompt_cond = nn.Parameter(torch.randn(dim_time))
+            self.null_prompt_tokens = nn.Parameter(torch.randn(num_latents_m, dim))
+
+            nn.init.normal_(self.null_prompt_cond, std = 0.02)
+            nn.init.normal_(self.null_prompt_tokens, std = 0.02)
+
             self.to_prompt_cond = Sequential(
                 Reduce('b n d -> b d', 'mean'),
                 nn.Linear(dim_prompt, dim_time),
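The learned `null_prompt_cond` and `null_prompt_tokens` parameters act as a trainable stand-in for "no prompt": dropped conditioning is replaced by these embeddings rather than zeros, so the unconditional pathway is itself learned. The swap relies on `torch.where` broadcasting a per-sample mask over whole rows; a standalone illustration with hypothetical shapes (not the library API):

```python
import torch
from einops import rearrange

b, m, d = 4, 32, 128                        # batch, resampled latents, model dim
null_tokens   = torch.randn(m, d) * 0.02    # plays the role of null_prompt_tokens
prompt_tokens = torch.randn(b, m, d)        # plays the role of the resampled prompt
drop_mask = torch.tensor([True, False, False, True])

# (b,) -> (b, 1, 1) so each True selects the entire (m, d) null block for that sample
c = torch.where(rearrange(drop_mask, 'b -> b 1 1'), null_tokens, prompt_tokens)
assert torch.allclose(c[0], null_tokens) and torch.allclose(c[1], prompt_tokens[1])
```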
@@ -606,23 +624,61 @@ def __init__(
             cross_attn = condition_on_prompt
         )
 
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def forward_with_cond_scale(
+        self,
+        *args,
+        cond_scale = 1.,
+        **kwargs
+    ):
+        logits = self.forward(*args, cond_drop_prob = 0., **kwargs)
+
+        if cond_scale == 1.:
+            return logits
+
+        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
+
+        return null_logits + (logits - null_logits) * cond_scale
+
     def forward(
         self,
         x,
         times,
-        prompt = None
+        prompt = None,
+        cond_drop_prob = None
     ):
+        b = x.shape[0]
+        cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)
+
+        drop_mask = prob_mask_like((b,), cond_drop_prob, self.device)
+
         t = self.to_time_cond(times)
         c = None
 
         if exists(self.to_prompt_cond):
             assert exists(prompt)
             encoded_prompt = self.speech_prompt_encoder(prompt)
 
-            p = self.to_prompt_cond(encoded_prompt)
-            t = torch.cat((t, p), dim = -1)
+            prompt_cond = self.to_prompt_cond(encoded_prompt)
 
-            c = self.perceiver_resampler(encoded_prompt)
+            prompt_cond = torch.where(
+                rearrange(drop_mask, 'b -> b 1'),
+                self.null_prompt_cond,
+                prompt_cond,
+            )
+
+            t = torch.cat((t, prompt_cond), dim = -1)
+
+            resampled_prompt_tokens = self.perceiver_resampler(encoded_prompt)
+
+            c = torch.where(
+                rearrange(drop_mask, 'b -> b 1 1'),
+                self.null_prompt_tokens,
+                resampled_prompt_tokens
+            )
 
         x = rearrange(x, 'b n d -> b d n')
         x = self.wavenet(x, t)
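Putting the pieces together: during training `forward` routes a random `cond_drop_prob` fraction of the batch through the null embeddings, so a single network learns both conditional and unconditional denoising; at inference `forward_with_cond_scale` runs one fully conditioned pass and one fully dropped pass and extrapolates between them. A self-contained toy module showing the same pattern (hypothetical, simplified, not the repository's model):

```python
import torch
from torch import nn

class ToyCFG(nn.Module):
    def __init__(self, dim, cond_drop_prob = 0.25):
        super().__init__()
        self.cond_drop_prob = cond_drop_prob
        self.null_cond = nn.Parameter(torch.randn(dim) * 0.02)    # learned 'no prompt' vector
        self.net = nn.Linear(dim * 2, dim)

    def forward(self, x, cond, cond_drop_prob = None):
        p = self.cond_drop_prob if cond_drop_prob is None else cond_drop_prob
        drop = torch.rand(x.shape[0], 1, device = x.device) < p   # per-sample Bernoulli mask
        cond = torch.where(drop, self.null_cond, cond)            # swap dropped rows for the null vector
        return self.net(torch.cat((x, cond), dim = -1))

    @torch.no_grad()
    def forward_with_cond_scale(self, x, cond, cond_scale = 1.):
        out = self.forward(x, cond, cond_drop_prob = 0.)          # fully conditioned pass
        if cond_scale == 1.:
            return out                                            # no guidance: single pass suffices
        null_out = self.forward(x, cond, cond_drop_prob = 1.)     # fully dropped pass
        return null_out + (out - null_out) * cond_scale

model = ToyCFG(dim = 8)
x, cond = torch.randn(3, 8), torch.randn(3, 8)
model(x, cond).pow(2).mean().backward()                           # training-style call, random dropout
sampled = model.forward_with_cond_scale(x, cond, cond_scale = 3.) # inference-style guided call
```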
@@ -884,7 +940,7 @@ def get_sampling_timesteps(self, batch, *, device):
         return times
 
     @torch.no_grad()
-    def ddpm_sample(self, shape, prompt = None, time_difference = None):
+    def ddpm_sample(self, shape, prompt = None, time_difference = None, cond_scale = 1.):
         batch, device = shape[0], self.device
 
         time_difference = default(time_difference, self.time_difference)
@@ -906,7 +962,7 @@ def ddpm_sample(self, shape, prompt = None, time_difference = None):
 
             # get predicted x0
 
-            model_output = self.model(audio, noise_cond, prompt = prompt)
+            model_output = self.model.forward_with_cond_scale(audio, noise_cond, prompt = prompt, cond_scale = cond_scale)
 
             # get log(snr)
 
@@ -953,7 +1009,7 @@ def ddpm_sample(self, shape, prompt = None, time_difference = None):
         return audio
 
     @torch.no_grad()
-    def ddim_sample(self, shape, prompt = None, time_difference = None):
+    def ddim_sample(self, shape, prompt = None, time_difference = None, cond_scale = 1.):
         batch, device = shape[0], self.device
 
         time_difference = default(time_difference, self.time_difference)
@@ -983,7 +1039,7 @@ def ddim_sample(self, shape, prompt = None, time_difference = None):
 
             # predict x0
 
-            model_output = self.model(audio, times, prompt = prompt)
+            model_output = self.model.forward_with_cond_scale(audio, times, prompt = prompt, cond_scale = cond_scale)
 
             # calculate x0 and noise
 
@@ -1028,7 +1084,8 @@ def sample(
         *,
         length,
         prompt = None,
-        batch_size = 1
+        batch_size = 1,
+        cond_scale = 1.
     ):
         sample_fn = self.ddpm_sample if not self.use_ddim else self.ddim_sample
 
@@ -1037,7 +1094,7 @@ def sample(
 
         if exists(prompt):
             batch_size = prompt.shape[0]
-        audio = sample_fn((batch_size, length, self.dim), prompt = prompt)
+        audio = sample_fn((batch_size, length, self.dim), prompt = prompt, cond_scale = cond_scale)
 
         if exists(self.codec):
             audio = self.codec.decode(audio)
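A note on cost: with `cond_scale != 1.` every sampling step makes two model evaluations (conditional plus null), while `cond_scale = 1.` short-circuits to a single pass, so unguided sampling costs the same as before this commit. Choosing the scale is usually done with a small sweep; an illustrative loop reusing `diffusion` and `prompt` from the README example above:

```python
# stronger guidance generally trades sample diversity for prompt adherence
for cond_scale in (1., 2., 3., 5.):
    generated_audio = diffusion.sample(length = 1024, prompt = prompt, cond_scale = cond_scale)
    # listen to / score the outputs and keep the scale that works best for your data
```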
naturalspeech2_pytorch/version.py (2 changes: 1 addition & 1 deletion)
@@ -1 +1 @@
-__version__ = '0.0.24'
+__version__ = '0.0.25'
