From 653d8f3d644eb24f463770cd02160e4ca16ce9ed Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 30 Oct 2024 08:57:38 -0700 Subject: [PATCH] attention is all you need for neural radiance fields --- README.md | 11 +++ setup.py | 2 +- tests/test_x_transformers.py | 18 ++++- x_transformers/__init__.py | 4 ++ x_transformers/neo_mlp.py | 126 +++++++++++++++++++++++++++++++++++ 5 files changed, 159 insertions(+), 2 deletions(-) create mode 100644 x_transformers/neo_mlp.py diff --git a/README.md b/README.md index 30c70e41..6edb9096 100644 --- a/README.md +++ b/README.md @@ -2341,4 +2341,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17) } ``` +```bibtex +@inproceedings{anonymous2024from, + title = {From {MLP} to Neo{MLP}: Leveraging Self-Attention for Neural Fields}, + author = {Anonymous}, + booktitle = {Submitted to The Thirteenth International Conference on Learning Representations}, + year = {2024}, + url = {https://openreview.net/forum?id=A8Vuf2e8y6}, + note = {under review} +} +``` + *solve intelligence... then use that to solve everything else.* - Demis Hassabis diff --git a/setup.py b/setup.py index 08159037..5a989f5a 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'x-transformers', packages = find_packages(exclude=['examples']), - version = '1.41.5', + version = '1.42.0', license='MIT', description = 'X-Transformers - Pytorch', author = 'Phil Wang', diff --git a/tests/test_x_transformers.py b/tests/test_x_transformers.py index 82aae586..f8bd0baa 100644 --- a/tests/test_x_transformers.py +++ b/tests/test_x_transformers.py @@ -6,7 +6,8 @@ TransformerWrapper, Encoder, Decoder, - AutoregressiveWrapper + AutoregressiveWrapper, + NeoMLP ) from x_transformers.multi_input import MultiInputTransformerWrapper @@ -357,3 +358,18 @@ def test_forgetting_transformer(): x = torch.randint(0, 20000, (2, 1024)) embed = model(x) + +def test_neo_mlp(): + + mlp = NeoMLP( + dim_in = 5, + dim_out = 7, + dim_hidden = 16, + depth = 5, + dim_model = 64, + ) + + x = torch.randn(3, 5) + + out = mlp(x) + assert out.shape == (3, 7) diff --git a/x_transformers/__init__.py b/x_transformers/__init__.py index d5ef6db6..1c15544b 100644 --- a/x_transformers/__init__.py +++ b/x_transformers/__init__.py @@ -32,3 +32,7 @@ from x_transformers.dpo import ( DPO ) + +from x_transformers.neo_mlp import ( + NeoMLP +) diff --git a/x_transformers/neo_mlp.py b/x_transformers/neo_mlp.py new file mode 100644 index 00000000..b7e2fa9b --- /dev/null +++ b/x_transformers/neo_mlp.py @@ -0,0 +1,126 @@ +from collections import namedtuple + +import torch +from torch import nn, tensor, pi, is_tensor +import torch.nn.functional as F +from torch.nn import Module, ModuleList + +from einops import rearrange, repeat, einsum, pack, unpack + +from x_transformers.x_transformers import ( + Encoder +) + +# helpers + +def exists(v): + return v is not None + +def default(v, d): + return v if exists(v) else d + +# random fourier + +class RandomFourierEmbed(Module): + + def __init__(self, dim): + super().__init__() + self.proj = nn.Linear(1, dim) + self.proj.requires_grad_(False) + + def forward( + self, + times, + ): + + times = rearrange(times, '... -> ... 
1') + rand_proj = self.proj(times) + return torch.cos(2 * pi * rand_proj) + +# class + +class NeoMLP(Module): + """ https://openreview.net/forum?id=A8Vuf2e8y6 """ + + def __init__( + self, + *, + dim_in, + dim_hidden, + dim_out, + dim_model, + depth, + encoder_kwargs: dict = dict( + attn_dim_head = 16, + heads = 4 + ) + ): + super().__init__() + + # input and output embeddings + + self.input_embed = nn.Parameter(torch.zeros(dim_in, dim_model)) + self.hidden_embed = nn.Parameter(torch.zeros(dim_hidden, dim_model)) + self.output_embed = nn.Parameter(torch.zeros(dim_out, dim_model)) + + nn.init.normal_(self.input_embed, std = 0.02) + nn.init.normal_(self.hidden_embed, std = 0.02) + nn.init.normal_(self.output_embed, std = 0.02) + + # they use random fourier for continuous features + + self.random_fourier = nn.Sequential( + RandomFourierEmbed(dim_model), + nn.Linear(dim_model, dim_model) + ) + + # hidden dimensions of mlp replaced with nodes with message passing + # which comes back to self attention as a fully connected graph. + + self.transformer = Encoder( + dim = dim_model, + depth = depth, + **encoder_kwargs + ) + + # output + + self.to_output_weights = nn.Parameter(torch.randn(dim_out, dim_model)) + self.to_output_bias = nn.Parameter(torch.zeros(dim_out)) + + def forward( + self, + x, + return_embeds = False + ): + batch = x.shape[0] + + fouriered_input = self.random_fourier(x) + + # add fouriered input to the input embedding + + input_embed = fouriered_input + self.input_embed + + hidden_embed, output_embed = tuple(repeat(t, '... -> b ...', b = batch) for t in (self.hidden_embed, self.output_embed)) + + # pack all the inputs into one string of tokens for self attention + + embed, packed_shape = pack([input_embed, hidden_embed, output_embed], 'b * d') + + # attention is all you need + + embed = self.transformer(embed) + + # unpack + + input_embed, hidden_embed, output_embed = unpack(embed, packed_shape, 'b * d') + + # project for output + + output = einsum(output_embed, self.to_output_weights, 'b n d, n d -> b n') + output = output + self.to_output_bias + + if not return_embeds: + return output + + return output, (input_embed, hidden_embed, output_embed)
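
A minimal usage sketch for the `NeoMLP` module added above, written against the constructor arguments exercised in `tests/test_x_transformers.py`. The coordinate-to-RGB fitting loop, the learning rate, and the tensor shapes are illustrative assumptions and are not part of the patch itself.

```python
import torch
import torch.nn.functional as F

from x_transformers import NeoMLP

# NeoMLP as a small neural field: continuous 2d coordinates in, rgb out.
# the hidden "neuron" tokens stand in for the hidden units of a classic MLP,
# with self attention acting as message passing over a fully connected graph

mlp = NeoMLP(
    dim_in = 2,       # e.g. (x, y) coordinates
    dim_out = 3,      # e.g. rgb
    dim_hidden = 16,  # number of hidden node tokens
    depth = 5,        # attention layers
    dim_model = 64    # width of every node token
)

coords = torch.rand(1024, 2)   # batch of input coordinates
targets = torch.rand(1024, 3)  # field values to regress onto

optim = torch.optim.Adam(mlp.parameters(), lr = 3e-4)

for _ in range(100):
    pred = mlp(coords)                 # (1024, 3)
    loss = F.mse_loss(pred, targets)
    loss.backward()
    optim.step()
    optim.zero_grad()

# the per-node embeddings can also be returned for inspection

pred, (input_embed, hidden_embed, output_embed) = mlp(coords, return_embeds = True)
```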