
Commit

use the perceiver resampler, which should be more powerful than the q-k-v attention with m latents in this paper
lucidrains committed May 11, 2023
1 parent e4e9bde commit f391bd0
Showing 3 changed files with 49 additions and 1 deletion.
10 changes: 10 additions & 0 deletions README.md
@@ -123,3 +123,13 @@ trainer.train()
year = {2023}
}
```

```bibtex
@article{Alayrac2022FlamingoAV,
title = {Flamingo: a Visual Language Model for Few-Shot Learning},
author = {Jean-Baptiste Alayrac and Jeff Donahue and Pauline Luc and Antoine Miech and Iain Barr and Yana Hasson and Karel Lenc and Arthur Mensch and Katie Millican and Malcolm Reynolds and Roman Ring and Eliza Rutherford and Serkan Cabi and Tengda Han and Zhitao Gong and Sina Samangooei and Marianne Monteiro and Jacob Menick and Sebastian Borgeaud and Andy Brock and Aida Nematzadeh and Sahand Sharifzadeh and Mikolaj Binkowski and Ricardo Barreira and Oriol Vinyals and Andrew Zisserman and Karen Simonyan},
journal = {ArXiv},
year = {2022},
volume = {abs/2204.14198}
}
```
38 changes: 38 additions & 0 deletions naturalspeech2_pytorch/naturalspeech2_pytorch.py
@@ -619,6 +619,44 @@ def forward(

return F.l1_loss(pred, labels)

# use perceiver resampler from flamingo paper - https://arxiv.org/abs/2204.14198
# in lieu of "q-k-v" attention, with the m queries becoming the key / values on which the ddpm network is conditioned

class PerceiverResampler(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        num_latents = 64, # m in the paper
        dim_head = 64,
        heads = 8,
        ff_mult = 4
    ):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        nn.init.normal_(self.latents, std = 0.02)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, dim_head = dim_head, heads = heads, cross_attn_include_queries = True),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

        self.norm = RMSNorm(dim)

    def forward(self, x):
        batch = x.shape[0]

        latents = repeat(self.latents, 'n d -> b n d', b = batch)

        for attn, ff in self.layers:
            latents = attn(latents, x) + latents
            latents = ff(latents) + latents

        return self.norm(latents)

# tensor helper functions

def log(t, eps = 1e-20):
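To illustrate the change, here is a minimal sketch of how the new `PerceiverResampler` might be exercised on its own; the `dim = 512` value, the prompt length, and the import path below are illustrative assumptions, not values taken from this commit.

```python
import torch
from naturalspeech2_pytorch.naturalspeech2_pytorch import PerceiverResampler  # assumed import path

# one resampler with m = 64 latents, matching the defaults added in this commit
resampler = PerceiverResampler(
    dim = 512,         # must match the dimension of the prompt / conditioning embeddings
    depth = 2,
    num_latents = 64,  # m in the flamingo paper
    dim_head = 64,
    heads = 8
)

prompt_embeds = torch.randn(1, 256, 512)  # (batch, variable prompt length, dim)
latents = resampler(prompt_embeds)        # (1, 64, 512) - fixed-size set of latents
```

However long the prompt is, the output always contains `num_latents` vectors, giving the ddpm network a fixed-size summary to be conditioned on.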
2 changes: 1 addition & 1 deletion naturalspeech2_pytorch/version.py
@@ -1 +1 @@
__version__ = '0.0.17'
__version__ = '0.0.18'
