complete basic film conditioning of the encoded speech prompt on ddpm side
lucidrains committed May 12, 2023
1 parent badaf25 commit 790b1aa
Showing 3 changed files with 371 additions and 254 deletions.
59 changes: 58 additions & 1 deletion README.md
@@ -28,7 +28,7 @@ $ pip install naturalspeech2-pytorch
import torch
from naturalspeech2_pytorch import (
    EncodecWrapper,
-   Transformer,
+   Model,
    NaturalSpeech2
)

@@ -63,6 +63,56 @@ generated_audio = diffusion.sample(length = 1024) # (1, 327680)

```

To condition on a speech prompt, instantiate a `SpeechPromptEncoder` and pass it to the `Model` as `speech_prompt_encoder`

For example:

```python
import torch
from naturalspeech2_pytorch import (
    EncodecWrapper,
    Model,
    NaturalSpeech2,
    SpeechPromptEncoder
)

# use encodec as an example

codec = EncodecWrapper()

prompt_encoder = SpeechPromptEncoder(
    dim_codebook = codec.codebook_dim,
    depth = 2
)

model = Model(
    dim = 128,
    depth = 6,
    speech_prompt_encoder = prompt_encoder # pass in the SpeechPromptEncoder
)

# natural speech diffusion model

diffusion = NaturalSpeech2(
    model = model,
    codec = codec,
    timesteps = 1000
).cuda()

# mock raw audio data

raw_audio = torch.randn(4, 327680).cuda()

# mock speech prompt; during training the paper excises a random range of the audio
# to use as the prompt, which will eventually be taken care of automatically
prompt = torch.randn(4, 32768).cuda()

loss = diffusion(raw_audio, prompt = prompt) # pass in the prompt
loss.backward()

# do the above in a loop for a lot of raw audio data...
# then you can sample from your generative model like so

generated_audio = diffusion.sample(length = 1024, prompt = prompt) # pass in your prompt
```
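
Until the prompt excision mentioned in the comment above is handled by the library, it has to be done by hand. A minimal sketch of picking a random window, assuming the prompt is a contiguous span of the same waveform; `random_prompt_span` is a hypothetical helper and not part of `naturalspeech2_pytorch`:

```python
import random

def random_prompt_span(audio_len, prompt_len, rng=random):
    """Pick a random [start, end) window of length prompt_len inside the audio.

    Hypothetical helper for illustration only.
    """
    if not 0 < prompt_len <= audio_len:
        raise ValueError("prompt_len must be in (0, audio_len]")
    # every valid start position is equally likely
    start = rng.randrange(audio_len - prompt_len + 1)
    return start, start + prompt_len

# toy "audio" of 10 samples; the mock waveform above is 327680 samples long
audio = list(range(10))
start, end = random_prompt_span(len(audio), prompt_len=4, rng=random.Random(0))
prompt = audio[start:end]  # contiguous 4-sample excerpt
```

With a batched tensor of shape `(batch, samples)`, the same indices would slice as `raw_audio[:, start:end]`. Whether the excised range should also be removed from the training target is not stated here, so the sketch leaves the original audio untouched.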

Or, if you want a `Trainer` class to take care of the training and sampling loop, simply do

```python
@@ -78,6 +128,13 @@ trainer = Trainer(
trainer.train()
```
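
The commit message describes the prompt conditioning as basic FiLM (feature-wise linear modulation). As a rough, generic illustration of that technique, and not the repository's actual code, each feature channel is scaled and shifted by values derived from the conditioning signal. The function name, the plain-list representation, and the `1 + scale` convention below are all assumptions:

```python
def film(features, scale, shift):
    """Feature-wise linear modulation with one (scale, shift) pair per channel.

    features: list of frames, each a list of per-channel floats
    scale, shift: per-channel lists derived from the conditioning signal
    """
    return [
        # (1 + scale) keeps the features unchanged when scale and shift are zero
        [x * (1.0 + g) + b for x, g, b in zip(frame, scale, shift)]
        for frame in features
    ]

# two frames, three channels
features = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
scale = [0.0, 1.0, -1.0]  # hypothetical output of a small network fed the encoded prompt
shift = [0.5, 0.0, 0.0]

modulated = film(features, scale, shift)
```

In a real model the scale and shift would be produced by a learned projection of the encoded speech prompt and applied inside the denoising network's layers.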

## Todo

- [ ] complete perceiver then cross attention conditioning on ddpm side
- [ ] add classifier free guidance, even if not in paper
- [ ] add self-conditioning on ddpm side
- [ ] complete duration / pitch prediction during training
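
Classifier free guidance is only a todo item here, and nothing above shows how it would be wired into `NaturalSpeech2`. As a generic sketch of the technique, the conditional and unconditional model predictions are blended at sampling time; `classifier_free_guidance`, `cond_pred`, `uncond_pred`, and `guidance_scale` are hypothetical names, not this library's API:

```python
def classifier_free_guidance(cond_pred, uncond_pred, guidance_scale):
    """Push the prediction away from the unconditional output, toward the conditional one.

    A scale of 1.0 returns the conditional prediction, 0.0 the unconditional one.
    """
    return [u + guidance_scale * (c - u) for c, u in zip(cond_pred, uncond_pred)]

cond_pred = [1.0, 2.0]    # model output with the prompt
uncond_pred = [0.0, 1.0]  # model output with the prompt dropped
guided = classifier_free_guidance(cond_pred, uncond_pred, guidance_scale=3.0)
```

Training for this usually involves dropping the conditioning at random for a fraction of batches so that the same network can produce both predictions.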

## Citations

```bibtex