diff --git a/README.md b/README.md
index dcc795b..316d8b7 100644
--- a/README.md
+++ b/README.md
@@ -123,3 +123,13 @@ trainer.train()
     year = {2023}
 }
 ```
+
+```bibtex
+@article{Alayrac2022FlamingoAV,
+    title = {Flamingo: a Visual Language Model for Few-Shot Learning},
+    author = {Jean-Baptiste Alayrac and Jeff Donahue and Pauline Luc and Antoine Miech and Iain Barr and Yana Hasson and Karel Lenc and Arthur Mensch and Katie Millican and Malcolm Reynolds and Roman Ring and Eliza Rutherford and Serkan Cabi and Tengda Han and Zhitao Gong and Sina Samangooei and Marianne Monteiro and Jacob Menick and Sebastian Borgeaud and Andy Brock and Aida Nematzadeh and Sahand Sharifzadeh and Mikolaj Binkowski and Ricardo Barreira and Oriol Vinyals and Andrew Zisserman and Karen Simonyan},
+    journal = {ArXiv},
+    year = {2022},
+    volume = {abs/2204.14198}
+}
+```
diff --git a/naturalspeech2_pytorch/naturalspeech2_pytorch.py b/naturalspeech2_pytorch/naturalspeech2_pytorch.py
index 5dc0f03..6ba3fe8 100644
--- a/naturalspeech2_pytorch/naturalspeech2_pytorch.py
+++ b/naturalspeech2_pytorch/naturalspeech2_pytorch.py
@@ -619,6 +619,44 @@ def forward(

         return F.l1_loss(pred, labels)

+# use perceiver resampler from flamingo paper - https://arxiv.org/abs/2204.14198
+# in lieu of "q-k-v" attention with the m queries becoming key / values on which ddpm network is conditioned on
+
+class PerceiverResampler(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        depth,
+        num_latents = 64, # m in the paper
+        dim_head = 64,
+        heads = 8,
+        ff_mult = 4
+    ):
+        super().__init__()
+        self.latents = nn.Parameter(torch.randn(num_latents, dim))
+        nn.init.normal_(self.latents, std = 0.02)
+
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(nn.ModuleList([
+                Attention(dim = dim, dim_head = dim_head, heads = heads, cross_attn_include_queries = True),
+                FeedForward(dim = dim, mult = ff_mult)
+            ]))
+
+        self.norm = RMSNorm(dim)
+
+    def forward(self, x):
+        batch = x.shape[0]
+
+        latents = repeat(self.latents, 'n d -> b n d', b = batch)
+
+        for attn, ff in self.layers:
+            latents = attn(latents, x) + latents
+            latents = ff(latents) + latents
+
+        return self.norm(latents)
+
 # tensor helper functions

 def log(t, eps = 1e-20):
diff --git a/naturalspeech2_pytorch/version.py b/naturalspeech2_pytorch/version.py
index b47451b..d9fc5d6 100644
--- a/naturalspeech2_pytorch/version.py
+++ b/naturalspeech2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.0.17'
+__version__ = '0.0.18'
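
For reference, a minimal usage sketch of the new `PerceiverResampler` (not part of the diff): it maps a variable-length sequence of prompt embeddings to a fixed set of `num_latents` vectors that the DDPM network can then attend over. The import path is taken from the file modified above; the `dim`, `depth`, and tensor shapes are illustrative assumptions, not values prescribed by the repo.

```python
import torch
from naturalspeech2_pytorch.naturalspeech2_pytorch import PerceiverResampler

# illustrative hyperparameters - dim / depth here are assumptions, not repo defaults
resampler = PerceiverResampler(
    dim = 512,           # embedding dimension of the prompt features
    depth = 2,           # number of (cross-attention, feedforward) blocks
    num_latents = 64     # m learned queries, as in the Flamingo paper
)

prompt_embeds = torch.randn(2, 256, 512)   # (batch, prompt length, dim) - prompt length can vary
latents = resampler(prompt_embeds)         # (2, 64, 512) - fixed-size conditioning for the ddpm network

print(latents.shape)   # torch.Size([2, 64, 512])
```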