diff --git a/setup.py b/setup.py
index 3db75a6..4d9e283 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.2.9',
+  version = '1.4.0',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description_content_type = 'text/markdown',
diff --git a/vit_pytorch/simple_vit.py b/vit_pytorch/simple_vit.py
index 2b63b60..f535693 100644
--- a/vit_pytorch/simple_vit.py
+++ b/vit_pytorch/simple_vit.py
@@ -64,6 +64,7 @@ def forward(self, x):
 class Transformer(nn.Module):
     def __init__(self, dim, depth, heads, dim_head, mlp_dim):
         super().__init__()
+        self.norm = nn.LayerNorm(dim)
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
@@ -74,7 +75,7 @@ def forward(self, x):
         for attn, ff in self.layers:
             x = attn(x) + x
             x = ff(x) + x
-        return x
+        return self.norm(x)
 
 class SimpleViT(nn.Module):
     def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
@@ -101,12 +102,10 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ml
 
         self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
 
-        self.to_latent = nn.Identity()
-        self.linear_head = nn.Sequential(
-            nn.LayerNorm(dim),
-            nn.Linear(dim, num_classes)
-        )
         self.pool = "mean"
+        self.to_latent = nn.Identity()
+
+        self.linear_head = nn.Linear(dim, num_classes)
 
     def forward(self, img):
         device = img.device
diff --git a/vit_pytorch/simple_vit_1d.py b/vit_pytorch/simple_vit_1d.py
index 1cca5a5..233f834 100644
--- a/vit_pytorch/simple_vit_1d.py
+++ b/vit_pytorch/simple_vit_1d.py
@@ -62,6 +62,7 @@ def forward(self, x):
 class Transformer(nn.Module):
     def __init__(self, dim, depth, heads, dim_head, mlp_dim):
         super().__init__()
+        self.norm = nn.LayerNorm(dim)
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
@@ -72,7 +73,7 @@ def forward(self, x):
         for attn, ff in self.layers:
             x = attn(x) + x
             x = ff(x) + x
-        return x
+        return self.norm(x)
 
 class SimpleViT(nn.Module):
     def __init__(self, *, seq_len, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
@@ -93,10 +94,7 @@ def __init__(self, *, seq_len, patch_size, num_classes, dim, depth, heads, mlp_d
         self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
 
         self.to_latent = nn.Identity()
-        self.linear_head = nn.Sequential(
-            nn.LayerNorm(dim),
-            nn.Linear(dim, num_classes)
-        )
+        self.linear_head = nn.Linear(dim, num_classes)
 
     def forward(self, series):
         *_, n, dtype = *series.shape, series.dtype
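For the SimpleViT variants above, the trailing LayerNorm moves out of the mean-pooled classification head and into `Transformer.forward`. A minimal standalone sketch (not repo code) of what that reordering means in practice: the norm is now applied to every token before mean pooling rather than once to the pooled vector, so the two placements are not numerically identical.

```python
import torch
import torch.nn as nn

dim, num_classes = 16, 10
tokens = torch.randn(2, 8, dim)            # residual stream leaving the last block

norm = nn.LayerNorm(dim)
head = nn.Linear(dim, num_classes)

# old placement: mean pool, then LayerNorm + Linear inside linear_head
old_logits = head(norm(tokens.mean(dim = 1)))

# new placement: Transformer returns self.norm(x); linear_head is a bare Linear
new_logits = head(norm(tokens).mean(dim = 1))

# LayerNorm and mean pooling do not commute, so the outputs differ
print((old_logits - new_logits).abs().max())
```

Something to keep in mind when comparing checkpoints trained before and after this change.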
diff --git a/vit_pytorch/simple_vit_3d.py b/vit_pytorch/simple_vit_3d.py
index 03691ce..8a1460f 100644
--- a/vit_pytorch/simple_vit_3d.py
+++ b/vit_pytorch/simple_vit_3d.py
@@ -77,6 +77,7 @@ def forward(self, x):
 class Transformer(nn.Module):
     def __init__(self, dim, depth, heads, dim_head, mlp_dim):
         super().__init__()
+        self.norm = nn.LayerNorm(dim)
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
@@ -87,7 +88,7 @@ def forward(self, x):
         for attn, ff in self.layers:
             x = attn(x) + x
             x = ff(x) + x
-        return x
+        return self.norm(x)
 
 class SimpleViT(nn.Module):
     def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
@@ -111,10 +112,7 @@ def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, nu
         self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
 
         self.to_latent = nn.Identity()
-        self.linear_head = nn.Sequential(
-            nn.LayerNorm(dim),
-            nn.Linear(dim, num_classes)
-        )
+        self.linear_head = nn.Linear(dim, num_classes)
 
     def forward(self, video):
         *_, h, w, dtype = *video.shape, video.dtype
diff --git a/vit_pytorch/simple_vit_with_patch_dropout.py b/vit_pytorch/simple_vit_with_patch_dropout.py
index e37d1e2..0e0e040 100644
--- a/vit_pytorch/simple_vit_with_patch_dropout.py
+++ b/vit_pytorch/simple_vit_with_patch_dropout.py
@@ -87,6 +87,7 @@ def forward(self, x):
 class Transformer(nn.Module):
     def __init__(self, dim, depth, heads, dim_head, mlp_dim):
         super().__init__()
+        self.norm = nn.LayerNorm(dim)
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
@@ -97,7 +98,7 @@ def forward(self, x):
         for attn, ff in self.layers:
             x = attn(x) + x
             x = ff(x) + x
-        return x
+        return self.norm(x)
 
 class SimpleViT(nn.Module):
     def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, patch_dropout = 0.5):
@@ -122,10 +123,7 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ml
         self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
 
         self.to_latent = nn.Identity()
-        self.linear_head = nn.Sequential(
-            nn.LayerNorm(dim),
-            nn.Linear(dim, num_classes)
-        )
+        self.linear_head = nn.Linear(dim, num_classes)
 
     def forward(self, img):
         *_, h, w, dtype = *img.shape, img.dtype
diff --git a/vit_pytorch/vit.py b/vit_pytorch/vit.py
index 796c741..5b34a44 100644
--- a/vit_pytorch/vit.py
+++ b/vit_pytorch/vit.py
@@ -11,24 +11,18 @@ def pair(t):
 
 # classes
 
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.norm = nn.LayerNorm(dim)
-        self.fn = fn
-    def forward(self, x, **kwargs):
-        return self.fn(self.norm(x), **kwargs)
-
 class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
+            nn.LayerNorm(dim),
             nn.Linear(dim, hidden_dim),
             nn.GELU(),
             nn.Dropout(dropout),
             nn.Linear(hidden_dim, dim),
             nn.Dropout(dropout)
         )
+
     def forward(self, x):
         return self.net(x)
 
@@ -41,6 +35,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         self.heads = heads
         self.scale = dim_head ** -0.5
 
+        self.norm = nn.LayerNorm(dim)
+
         self.attend = nn.Softmax(dim = -1)
         self.dropout = nn.Dropout(dropout)
 
@@ -52,6 +48,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         ) if project_out else nn.Identity()
 
     def forward(self, x):
+        x = self.norm(x)
+
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
@@ -67,17 +65,20 @@ def forward(self, x):
 class Transformer(nn.Module):
     def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
         super().__init__()
+        self.norm = nn.LayerNorm(dim)
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
-                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
+                FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
+
     def forward(self, x):
         for attn, ff in self.layers:
             x = attn(x) + x
             x = ff(x) + x
-        return x
+
+        return self.norm(x)
 
 class ViT(nn.Module):
     def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
@@ -107,10 +108,7 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ml
         self.pool = pool
         self.to_latent = nn.Identity()
 
-        self.mlp_head = nn.Sequential(
-            nn.LayerNorm(dim),
-            nn.Linear(dim, num_classes)
-        )
+        self.mlp_head = nn.Linear(dim, num_classes)
 
     def forward(self, img):
         x = self.to_patch_embedding(img)
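For `vit.py`, which defaults to cls pooling, the same move is exact: LayerNorm acts on each token independently, so normalizing the full sequence and then selecting the cls token equals selecting first and then normalizing. A minimal sketch (standalone, not repo code) of that identity; the PreNorm removal itself is purely structural, since each Attention/FeedForward block now applies the same `self.norm(x)` as its first step.

```python
import torch
import torch.nn as nn

dim, num_classes = 32, 10
x = torch.randn(2, 65, dim)                # cls token + 64 patch tokens

norm = nn.LayerNorm(dim)
head = nn.Linear(dim, num_classes)

# old: mlp_head = Sequential(LayerNorm, Linear) applied to the pooled cls token
old_logits = head(norm(x[:, 0]))

# new: Transformer.forward ends with self.norm(x); mlp_head is a bare Linear
new_logits = head(norm(x)[:, 0])

# per-token normalization commutes with cls pooling
assert torch.allclose(old_logits, new_logits, atol = 1e-6)
```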
diff --git a/vit_pytorch/vit_with_patch_merger.py b/vit_pytorch/vit_with_patch_merger.py
index 7f1360b..add3728 100644
--- a/vit_pytorch/vit_with_patch_merger.py
+++ b/vit_pytorch/vit_with_patch_merger.py
@@ -32,18 +32,11 @@ def forward(self, x):
 
 # classes
 
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.norm = nn.LayerNorm(dim)
-        self.fn = fn
-    def forward(self, x, **kwargs):
-        return self.fn(self.norm(x), **kwargs)
-
 class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
+            nn.LayerNorm(dim),
             nn.Linear(dim, hidden_dim),
             nn.GELU(),
             nn.Dropout(dropout),
@@ -62,6 +55,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         self.heads = heads
         self.scale = dim_head ** -0.5
 
+        self.norm = nn.LayerNorm(dim)
         self.attend = nn.Softmax(dim = -1)
         self.dropout = nn.Dropout(dropout)
 
@@ -73,6 +67,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         ) if project_out else nn.Identity()
 
     def forward(self, x):
+        x = self.norm(x)
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
@@ -88,6 +83,7 @@ def forward(self, x):
 class Transformer(nn.Module):
     def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., patch_merge_layer = None, patch_merge_num_tokens = 8):
         super().__init__()
+        self.norm = nn.LayerNorm(dim)
         self.layers = nn.ModuleList([])
 
         self.patch_merge_layer_index = default(patch_merge_layer, depth // 2) - 1 # default to mid-way through transformer, as shown in paper
@@ -95,8 +91,8 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., patch_mer
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
-                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
+                FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
 
     def forward(self, x):
         for index, (attn, ff) in enumerate(self.layers):
@@ -106,7 +102,7 @@ def forward(self, x):
             if index == self.patch_merge_layer_index:
                 x = self.patch_merger(x)
 
-        return x
+        return self.norm(x)
 
 class ViT(nn.Module):
     def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, patch_merge_layer = None, patch_merge_num_tokens = 8, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
@@ -133,7 +129,6 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ml
 
         self.mlp_head = nn.Sequential(
             Reduce('b n d -> b d', 'mean'),
-            nn.LayerNorm(dim),
             nn.Linear(dim, num_classes)
         )
 
diff --git a/vit_pytorch/vivit.py b/vit_pytorch/vivit.py
index 50daa65..2df8f01 100644
--- a/vit_pytorch/vivit.py
+++ b/vit_pytorch/vivit.py
@@ -70,6 +70,7 @@ def forward(self, x):
 class Transformer(nn.Module):
     def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
         super().__init__()
+        self.norm = nn.LayerNorm(dim)
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
@@ -80,7 +81,7 @@ def forward(self, x):
         for attn, ff in self.layers:
             x = attn(x) + x
             x = ff(x) + x
-        return x
+        return self.norm(x)
 
 class ViT(nn.Module):
     def __init__(
@@ -137,10 +138,7 @@ def __init__(
         self.pool = pool
         self.to_latent = nn.Identity()
 
-        self.mlp_head = nn.Sequential(
-            nn.LayerNorm(dim),
-            nn.Linear(dim, num_classes)
-        )
+        self.mlp_head = nn.Linear(dim, num_classes)
 
     def forward(self, video):
         x = self.to_patch_embedding(video)
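A quick end-to-end smoke test of the refactored models, mirroring the README-style usage (the hyperparameters are the README's illustrative values, not anything this change requires):

```python
import torch
from vit_pytorch import ViT, SimpleViT

kwargs = dict(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

img = torch.randn(1, 3, 256, 256)

v = ViT(dropout = 0.1, emb_dropout = 0.1, **kwargs)
sv = SimpleViT(**kwargs)

assert v(img).shape == (1, 1000)    # mlp_head is now a bare nn.Linear
assert sv(img).shape == (1, 1000)   # linear_head likewise projects to num_classes
```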