attention is all you need for neural radiance fields
lucidrains committed Oct 30, 2024
1 parent 0fe1aed commit 653d8f3
Showing 5 changed files with 159 additions and 2 deletions.
11 changes: 11 additions & 0 deletions README.md
@@ -2341,4 +2341,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
}
```

```bibtex
@inproceedings{anonymous2024from,
title = {From {MLP} to Neo{MLP}: Leveraging Self-Attention for Neural Fields},
author = {Anonymous},
booktitle = {Submitted to The Thirteenth International Conference on Learning Representations},
year = {2024},
url = {https://openreview.net/forum?id=A8Vuf2e8y6},
note = {under review}
}
```

*solve intelligence... then use that to solve everything else.* - Demis Hassabis
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
- version = '1.41.5',
+ version = '1.42.0',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
18 changes: 17 additions & 1 deletion tests/test_x_transformers.py
@@ -6,7 +6,8 @@
TransformerWrapper,
Encoder,
Decoder,
- AutoregressiveWrapper
+ AutoregressiveWrapper,
+ NeoMLP
)

from x_transformers.multi_input import MultiInputTransformerWrapper
@@ -357,3 +358,18 @@ def test_forgetting_transformer():
x = torch.randint(0, 20000, (2, 1024))

embed = model(x)

def test_neo_mlp():

mlp = NeoMLP(
dim_in = 5,
dim_out = 7,
dim_hidden = 16,
depth = 5,
dim_model = 64,
)

x = torch.randn(3, 5)

out = mlp(x)
assert out.shape == (3, 7)
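
The test above only checks shapes. For orientation, here is a minimal sketch of actually fitting NeoMLP as a neural field on a toy coordinate-to-value regression — the target function, hyperparameters, and training loop are illustrative assumptions, not part of this commit; only the constructor and forward signatures exercised by the test are taken from the diff.

```python
# hypothetical sketch: fit NeoMLP to a toy 2d coordinate -> 3 channel mapping

import torch
import torch.nn.functional as F
from x_transformers import NeoMLP

field = NeoMLP(
    dim_in = 2,      # 2d coordinates
    dim_out = 3,     # e.g. rgb
    dim_hidden = 64,
    depth = 4,
    dim_model = 64
)

opt = torch.optim.Adam(field.parameters(), lr = 3e-4)

coords = torch.rand(1024, 2)   # random coordinates in [0, 1)

# toy target signal to regress

target = torch.stack((
    coords[:, 0].sin(),
    coords[:, 1].cos(),
    coords[:, 0] * coords[:, 1]
), dim = -1)

for _ in range(100):
    pred = field(coords)             # (1024, 3)
    loss = F.mse_loss(pred, target)
    loss.backward()
    opt.step()
    opt.zero_grad()
```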
4 changes: 4 additions & 0 deletions x_transformers/__init__.py
@@ -32,3 +32,7 @@
from x_transformers.dpo import (
DPO
)

from x_transformers.neo_mlp import (
NeoMLP
)
126 changes: 126 additions & 0 deletions x_transformers/neo_mlp.py
@@ -0,0 +1,126 @@
from collections import namedtuple

import torch
from torch import nn, tensor, pi, is_tensor
import torch.nn.functional as F
from torch.nn import Module, ModuleList

from einops import rearrange, repeat, einsum, pack, unpack

from x_transformers.x_transformers import (
Encoder
)

# helpers

def exists(v):
return v is not None

def default(v, d):
return v if exists(v) else d

# random fourier

class RandomFourierEmbed(Module):

def __init__(self, dim):
super().__init__()
self.proj = nn.Linear(1, dim)
self.proj.requires_grad_(False)
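# the projection is frozen, so it acts as a bank of fixed random fourier frequencies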

def forward(
self,
times,
):

times = rearrange(times, '... -> ... 1')
rand_proj = self.proj(times)
return torch.cos(2 * pi * rand_proj)

# class

class NeoMLP(Module):
""" https://openreview.net/forum?id=A8Vuf2e8y6 """

def __init__(
self,
*,
dim_in,
dim_hidden,
dim_out,
dim_model,
depth,
encoder_kwargs: dict = dict(
attn_dim_head = 16,
heads = 4
)
):
super().__init__()

# input and output embeddings

self.input_embed = nn.Parameter(torch.zeros(dim_in, dim_model))
self.hidden_embed = nn.Parameter(torch.zeros(dim_hidden, dim_model))
self.output_embed = nn.Parameter(torch.zeros(dim_out, dim_model))

nn.init.normal_(self.input_embed, std = 0.02)
nn.init.normal_(self.hidden_embed, std = 0.02)
nn.init.normal_(self.output_embed, std = 0.02)

# they use random fourier for continuous features

self.random_fourier = nn.Sequential(
RandomFourierEmbed(dim_model),
nn.Linear(dim_model, dim_model)
)

# hidden dimensions of mlp replaced with nodes with message passing
# which comes back to self attention as a fully connected graph.
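# in other words, every input feature, hidden unit, and output becomes its own token, and attention stands in for the mlp's weight matrices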

self.transformer = Encoder(
dim = dim_model,
depth = depth,
**encoder_kwargs
)

# output

self.to_output_weights = nn.Parameter(torch.randn(dim_out, dim_model))
self.to_output_bias = nn.Parameter(torch.zeros(dim_out))

def forward(
self,
x,
return_embeds = False
):
batch = x.shape[0]

fouriered_input = self.random_fourier(x)

# add fouriered input to the input embedding

input_embed = fouriered_input + self.input_embed

hidden_embed, output_embed = tuple(repeat(t, '... -> b ...', b = batch) for t in (self.hidden_embed, self.output_embed))

# pack all the inputs into one string of tokens for self attention

embed, packed_shape = pack([input_embed, hidden_embed, output_embed], 'b * d')
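# sequence length is dim_in + dim_hidden + dim_out tokens, each of width dim_model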

# attention is all you need

embed = self.transformer(embed)

# unpack

input_embed, hidden_embed, output_embed = unpack(embed, packed_shape, 'b * d')

# project for output

output = einsum(output_embed, self.to_output_weights, 'b n d, n d -> b n')
output = output + self.to_output_bias

if not return_embeds:
return output

return output, (input_embed, hidden_embed, output_embed)
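
As a usage note beyond the diff itself (a sketch reusing the test's hyperparameters, not an official example): with dim_in = 5, dim_hidden = 16, dim_out = 7, the encoder attends over 5 + 16 + 7 = 28 tokens of width dim_model, and return_embeds = True additionally returns the three unpacked token groups.

```python
# hypothetical usage sketch, following the forward pass above

import torch
from x_transformers import NeoMLP

mlp = NeoMLP(dim_in = 5, dim_hidden = 16, dim_out = 7, dim_model = 64, depth = 5)

x = torch.randn(3, 5)

out, (input_embed, hidden_embed, output_embed) = mlp(x, return_embeds = True)

assert out.shape == (3, 7)
assert input_embed.shape  == (3, 5, 64)    # one token per input feature
assert hidden_embed.shape == (3, 16, 64)   # one token per hidden node
assert output_embed.shape == (3, 7, 64)    # one token per output
```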
