diff --git a/README.md b/README.md index e8663c4..9e9b57e 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ - [CrossFormer](#crossformer) - [RegionViT](#regionvit) - [NesT](#nest) +- [MobileViT](#mobilevit) - [Masked Autoencoder](#masked-autoencoder) - [Simple Masked Image Modeling](#simple-masked-image-modeling) - [Masked Patch Prediction](#masked-patch-prediction) @@ -549,6 +550,31 @@ img = torch.randn(1, 3, 224, 224) pred = nest(img) # (1, 1000) ``` +## MobileViT + + + +This paper introduce MobileViT, a light-weight and generalpurpose vision transformer for mobile devices. MobileViT presents a different +perspective for the global processing of information with transformers. + +You can use it with the following code (ex. mobilevit_xs) + +``` +import torch +from vit_pytorch.mobile_vit import MobileViT + +mbvit_xs = MobileViT( + image_size=(256, 256), + dims = [96, 120, 144], + channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384], + num_classes = 1000 +) + +img = torch.randn(1, 3, 256, 256) + +pred = mbvit_xs(img) # (1, 1000) +``` + ## Simple Masked Image Modeling diff --git a/images/mbvit.png b/images/mbvit.png new file mode 100644 index 0000000..503debf Binary files /dev/null and b/images/mbvit.png differ diff --git a/vit_pytorch/mobile_vit.py b/vit_pytorch/mobile_vit.py new file mode 100644 index 0000000..bb44d02 --- /dev/null +++ b/vit_pytorch/mobile_vit.py @@ -0,0 +1,229 @@ +""" +An implementation of MobileViT Model as defined in: +MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer +Arxiv: https://arxiv.org/abs/2110.02178 +Origin Code: https://github.com/murufeng/awesome_lightweight_networks +""" + +import torch +import torch.nn as nn + +from einops import rearrange + +def _make_divisible(v, divisor, min_value=None): + + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +def Conv_BN_ReLU(inp, oup, kernel, stride=1): + return nn.Sequential( + nn.Conv2d(inp, oup, kernel_size=kernel, stride=stride, padding=1, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU6(inplace=True) + ) + + +def conv_1x1_bn(inp, oup): + return nn.Sequential( + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU6(inplace=True) + ) + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout=0.): + super().__init__() + self.ffn = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + + def forward(self, x): + return self.ffn(x) + + +class Attention(nn.Module): + def __init__(self, dim, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim=-1) + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + attn = self.attend(dots) + out = torch.matmul(attn, v) + out = rearrange(out, 'b p h n d -> b p n (h d)') + return self.to_out(out) + + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + +class MV2Block(nn.Module): + def __init__(self, inp, oup, stride=1, expand_ratio=4): + super(MV2Block, self).__init__() + assert stride in [1, 2] + + hidden_dim = round(inp * expand_ratio) + self.identity = stride == 1 and inp == oup + + if expand_ratio == 1: + self.conv = nn.Sequential( + # dw + nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), + nn.BatchNorm2d(hidden_dim), + nn.SiLU(), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ) + else: + self.conv = nn.Sequential( + # pw + nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), + nn.BatchNorm2d(hidden_dim), + nn.SiLU(), + # dw + nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), + nn.BatchNorm2d(hidden_dim), + nn.SiLU(), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ) + + def forward(self, x): + if self.identity: + return x + self.conv(x) + else: + return self.conv(x) + +class MobileViTBlock(nn.Module): + def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.): + super().__init__() + self.ph, self.pw = patch_size + + self.conv1 = Conv_BN_ReLU(channel, channel, kernel_size) + self.conv2 = conv_1x1_bn(channel, dim) + + self.transformer = Transformer(dim, depth, 1, 32, mlp_dim, dropout) + + self.conv3 = conv_1x1_bn(dim, channel) + self.conv4 = Conv_BN_ReLU(2 * channel, channel, kernel_size) + + def forward(self, x): + y = x.clone() + + # Local representations + x = self.conv1(x) + x = self.conv2(x) + + # Global representations + _, _, h, w = x.shape + x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw) + x = self.transformer(x) + x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph, + pw=self.pw) + + # Fusion + x = self.conv3(x) + x = torch.cat((x, y), 1) + x = self.conv4(x) + return x + + +class MobileViT(nn.Module): + def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)): + super().__init__() + ih, iw = image_size + ph, pw = patch_size + assert ih % ph == 0 and iw % pw == 0 + + L = [2, 4, 3] + + self.conv1 = Conv_BN_ReLU(3, channels[0], kernel=3, stride=2) + + self.mv2 = nn.ModuleList([]) + self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion)) + self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion)) + self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion)) + self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion)) + self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion)) + self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion)) + self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion)) + + self.mvit = nn.ModuleList([]) + self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0] * 2))) + self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1] * 4))) + self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2] * 4))) + + self.conv2 = conv_1x1_bn(channels[-2], channels[-1]) + + self.pool = nn.AvgPool2d(ih // 32, 1) + self.fc = nn.Linear(channels[-1], num_classes, bias=False) + + def forward(self, x): + x = self.conv1(x) + x = self.mv2[0](x) + + x = self.mv2[1](x) + x = self.mv2[2](x) + x = self.mv2[3](x) + + x = self.mv2[4](x) + x = self.mvit[0](x) + + x = self.mv2[5](x) + x = self.mvit[1](x) + + x = self.mv2[6](x) + x = self.mvit[2](x) + x = self.conv2(x) + + x = self.pool(x).view(-1, x.shape[1]) + x = self.fc(x) + return x +