Skip to content

Commit

Permalink
add a hypersphere vit, adapted from https://arxiv.org/abs/2410.01131
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Oct 9, 2024
1 parent 82f2fa7 commit 43cbcad
Show file tree
Hide file tree
Showing 4 changed files with 272 additions and 3 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2133,4 +2133,13 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```

```bibtex
@inproceedings{Loshchilov2024nGPTNT,
title = {nGPT: Normalized Transformer with Representation Learning on the Hypersphere},
author = {Ilya Loshchilov and Cheng-Ping Hsieh and Simeng Sun and Boris Ginsburg},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:273026160}
}
```

*I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.7.14',
version = '1.8.0',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description=long_description,
Expand Down
260 changes: 260 additions & 0 deletions vit_pytorch/normalized_vit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
import torch
from torch import nn
from torch.nn import Module, ModuleList
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize

from einops import rearrange, reduce
from einops.layers.torch import Rearrange

# functions

def exists(v):
return v is not None

def default(v, d):
return v if exists(v) else d

def pair(t):
return t if isinstance(t, tuple) else (t, t)

def divisible_by(numer, denom):
return (numer % denom) == 0

def l2norm(t, dim = -1):
return F.normalize(t, dim = dim, p = 2)

# for use with parametrize

class L2Norm(Module):
def __init__(self, dim = -1):
super().__init__()
self.dim = dim

def forward(self, t):
return l2norm(t, dim = self.dim)

class NormLinear(Module):
def __init__(
self,
dim,
dim_out,
norm_dim_in = True
):
super().__init__()
self.linear = nn.Linear(dim, dim_out, bias = False)

parametrize.register_parametrization(
self.linear,
'weight',
L2Norm(dim = -1 if norm_dim_in else 0)
)

@property
def weight(self):
return self.linear.weight

def forward(self, x):
return self.linear(x)

# attention and feedforward

class Attention(Module):
def __init__(
self,
dim,
*,
dim_head = 64,
heads = 8,
dropout = 0.
):
super().__init__()
dim_inner = dim_head * heads
self.to_q = NormLinear(dim, dim_inner)
self.to_k = NormLinear(dim, dim_inner)
self.to_v = NormLinear(dim, dim_inner)

self.dropout = dropout

self.qk_scale = nn.Parameter(torch.ones(dim_head) * (dim_head ** 0.25))

self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
self.merge_heads = Rearrange('b h n d -> b n (h d)')

self.to_out = NormLinear(dim_inner, dim, norm_dim_in = False)

def forward(
self,
x
):
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)

q, k, v = map(self.split_heads, (q, k, v))

# query key rmsnorm

q, k = map(l2norm, (q, k))
q, k = (q * self.qk_scale), (k * self.qk_scale)

# scale is 1., as scaling factor is moved to s_qk (dk ^ 0.25) - eq. 16

out = F.scaled_dot_product_attention(
q, k, v,
dropout_p = self.dropout if self.training else 0.,
scale = 1.
)

out = self.merge_heads(out)
return self.to_out(out)

class FeedForward(Module):
def __init__(
self,
dim,
*,
dim_inner,
dropout = 0.
):
super().__init__()
dim_inner = int(dim_inner * 2 / 3)

self.dim = dim
self.dropout = nn.Dropout(dropout)

self.to_hidden = NormLinear(dim, dim_inner)
self.to_gate = NormLinear(dim, dim_inner)

self.hidden_scale = nn.Parameter(torch.ones(dim_inner))
self.gate_scale = nn.Parameter(torch.ones(dim_inner))

self.to_out = NormLinear(dim_inner, dim, norm_dim_in = False)

def forward(self, x):
hidden, gate = self.to_hidden(x), self.to_gate(x)

hidden = hidden * self.hidden_scale
gate = gate * self.gate_scale * (self.dim ** 0.5)

hidden = F.silu(gate) * hidden

hidden = self.dropout(hidden)
return self.to_out(hidden)

# classes

class nViT(Module):
""" https://arxiv.org/abs/2410.01131 """

def __init__(
self,
*,
image_size,
patch_size,
num_classes,
dim,
depth,
heads,
mlp_dim,
dropout = 0.,
channels = 3,
dim_head = 64,
residual_lerp_scale_init = None
):
super().__init__()
image_height, image_width = pair(image_size)

# calculate patching related stuff

assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'

patch_height_dim, patch_width_dim = (image_height // patch_size), (image_width // patch_size)
patch_dim = channels * (patch_size ** 2)
num_patches = patch_height_dim * patch_width_dim

self.channels = channels
self.patch_size = patch_size

self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_size, p2 = patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)

self.abs_pos_emb = nn.Embedding(num_patches, dim)

residual_lerp_scale_init = default(residual_lerp_scale_init, 1. / depth)

# layers

self.dim = dim
self.layers = ModuleList([])
self.residual_lerp_scales = nn.ParameterList([])

for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, dim_head = dim_head, heads = heads, dropout = dropout),
FeedForward(dim, dim_inner = mlp_dim, dropout = dropout),
]))

self.residual_lerp_scales.append(nn.ParameterList([
nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
]))

self.logit_scale = nn.Parameter(torch.ones(num_classes))

self.to_pred = NormLinear(dim, num_classes)

@torch.no_grad()
def norm_weights_(self):
for module in self.modules():
if not isinstance(module, NormLinear):
continue

normed = module.weight
original = module.linear.parametrizations.weight.original

original.copy_(normed)

def forward(self, images):
device = images.device

tokens = self.to_patch_embedding(images)

pos_emb = self.abs_pos_emb(torch.arange(tokens.shape[-2], device = device))

tokens = l2norm(tokens + pos_emb)

for (attn, ff), (attn_alpha, ff_alpha) in zip(self.layers, self.residual_lerp_scales):

attn_out = l2norm(attn(tokens))
tokens = l2norm(tokens.lerp(attn_out, attn_alpha))

ff_out = l2norm(ff(tokens))
tokens = l2norm(tokens.lerp(ff_out, ff_alpha))

pooled = reduce(tokens, 'b n d -> b d', 'mean')

logits = self.to_pred(pooled)
logits = logits * self.logit_scale * (self.dim ** 0.5)

return logits

# quick test

if __name__ == '__main__':

v = nViT(
image_size = 256,
patch_size = 16,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048,
)

img = torch.randn(4, 3, 256, 256)
logits = v(img) # (4, 1000)
assert logits.shape == (4, 1000)
4 changes: 2 additions & 2 deletions vit_pytorch/rvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch.amp import autocast

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# rotary embeddings

@autocast(enabled = False)
@autocast('cuda', enabled = False)
def rotate_every_two(x):
x = rearrange(x, '... (d j) -> ... d j', j = 2)
x1, x2 = x.unbind(dim = -1)
Expand Down

0 comments on commit 43cbcad

Please sign in to comment.