diff --git a/README.md b/README.md
index e8663c4..9e9b57e 100644
--- a/README.md
+++ b/README.md
@@ -19,6 +19,7 @@
- [CrossFormer](#crossformer)
- [RegionViT](#regionvit)
- [NesT](#nest)
+- [MobileViT](#mobilevit)
- [Masked Autoencoder](#masked-autoencoder)
- [Simple Masked Image Modeling](#simple-masked-image-modeling)
- [Masked Patch Prediction](#masked-patch-prediction)
@@ -549,6 +550,31 @@ img = torch.randn(1, 3, 224, 224)
pred = nest(img) # (1, 1000)
```
+## MobileViT
+
+
+
+This paper introduce MobileViT, a light-weight and generalpurpose vision transformer for mobile devices. MobileViT presents a different
+perspective for the global processing of information with transformers.
+
+You can use it with the following code (ex. mobilevit_xs)
+
+```
+import torch
+from vit_pytorch.mobile_vit import MobileViT
+
+mbvit_xs = MobileViT(
+ image_size=(256, 256),
+ dims = [96, 120, 144],
+ channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
+ num_classes = 1000
+)
+
+img = torch.randn(1, 3, 256, 256)
+
+pred = mbvit_xs(img) # (1, 1000)
+```
+
## Simple Masked Image Modeling
diff --git a/images/mbvit.png b/images/mbvit.png
new file mode 100644
index 0000000..503debf
Binary files /dev/null and b/images/mbvit.png differ
diff --git a/vit_pytorch/mobile_vit.py b/vit_pytorch/mobile_vit.py
new file mode 100644
index 0000000..bb44d02
--- /dev/null
+++ b/vit_pytorch/mobile_vit.py
@@ -0,0 +1,229 @@
+"""
+An implementation of MobileViT Model as defined in:
+MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer
+Arxiv: https://arxiv.org/abs/2110.02178
+Origin Code: https://github.com/murufeng/awesome_lightweight_networks
+"""
+
+import torch
+import torch.nn as nn
+
+from einops import rearrange
+
+def _make_divisible(v, divisor, min_value=None):
+
+ if min_value is None:
+ min_value = divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ if new_v < 0.9 * v:
+ new_v += divisor
+ return new_v
+
+
+def Conv_BN_ReLU(inp, oup, kernel, stride=1):
+ return nn.Sequential(
+ nn.Conv2d(inp, oup, kernel_size=kernel, stride=stride, padding=1, bias=False),
+ nn.BatchNorm2d(oup),
+ nn.ReLU6(inplace=True)
+ )
+
+
+def conv_1x1_bn(inp, oup):
+ return nn.Sequential(
+ nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(oup),
+ nn.ReLU6(inplace=True)
+ )
+
+class PreNorm(nn.Module):
+ def __init__(self, dim, fn):
+ super().__init__()
+ self.norm = nn.LayerNorm(dim)
+ self.fn = fn
+
+ def forward(self, x, **kwargs):
+ return self.fn(self.norm(x), **kwargs)
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, hidden_dim, dropout=0.):
+ super().__init__()
+ self.ffn = nn.Sequential(
+ nn.Linear(dim, hidden_dim),
+ nn.SiLU(),
+ nn.Dropout(dropout),
+ nn.Linear(hidden_dim, dim),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x):
+ return self.ffn(x)
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
+ super().__init__()
+ inner_dim = dim_head * heads
+ project_out = not (heads == 1 and dim_head == dim)
+
+ self.heads = heads
+ self.scale = dim_head ** -0.5
+
+ self.attend = nn.Softmax(dim=-1)
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, dim),
+ nn.Dropout(dropout)
+ ) if project_out else nn.Identity()
+
+ def forward(self, x):
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
+ q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
+
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+ attn = self.attend(dots)
+ out = torch.matmul(attn, v)
+ out = rearrange(out, 'b p h n d -> b p n (h d)')
+ return self.to_out(out)
+
+
+class Transformer(nn.Module):
+ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
+ super().__init__()
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(nn.ModuleList([
+ PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
+ PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+ ]))
+ def forward(self, x):
+ for attn, ff in self.layers:
+ x = attn(x) + x
+ x = ff(x) + x
+ return x
+
+class MV2Block(nn.Module):
+ def __init__(self, inp, oup, stride=1, expand_ratio=4):
+ super(MV2Block, self).__init__()
+ assert stride in [1, 2]
+
+ hidden_dim = round(inp * expand_ratio)
+ self.identity = stride == 1 and inp == oup
+
+ if expand_ratio == 1:
+ self.conv = nn.Sequential(
+ # dw
+ nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
+ nn.BatchNorm2d(hidden_dim),
+ nn.SiLU(),
+ # pw-linear
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(oup),
+ )
+ else:
+ self.conv = nn.Sequential(
+ # pw
+ nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(hidden_dim),
+ nn.SiLU(),
+ # dw
+ nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
+ nn.BatchNorm2d(hidden_dim),
+ nn.SiLU(),
+ # pw-linear
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(oup),
+ )
+
+ def forward(self, x):
+ if self.identity:
+ return x + self.conv(x)
+ else:
+ return self.conv(x)
+
+class MobileViTBlock(nn.Module):
+ def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
+ super().__init__()
+ self.ph, self.pw = patch_size
+
+ self.conv1 = Conv_BN_ReLU(channel, channel, kernel_size)
+ self.conv2 = conv_1x1_bn(channel, dim)
+
+ self.transformer = Transformer(dim, depth, 1, 32, mlp_dim, dropout)
+
+ self.conv3 = conv_1x1_bn(dim, channel)
+ self.conv4 = Conv_BN_ReLU(2 * channel, channel, kernel_size)
+
+ def forward(self, x):
+ y = x.clone()
+
+ # Local representations
+ x = self.conv1(x)
+ x = self.conv2(x)
+
+ # Global representations
+ _, _, h, w = x.shape
+ x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
+ x = self.transformer(x)
+ x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph,
+ pw=self.pw)
+
+ # Fusion
+ x = self.conv3(x)
+ x = torch.cat((x, y), 1)
+ x = self.conv4(x)
+ return x
+
+
+class MobileViT(nn.Module):
+ def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)):
+ super().__init__()
+ ih, iw = image_size
+ ph, pw = patch_size
+ assert ih % ph == 0 and iw % pw == 0
+
+ L = [2, 4, 3]
+
+ self.conv1 = Conv_BN_ReLU(3, channels[0], kernel=3, stride=2)
+
+ self.mv2 = nn.ModuleList([])
+ self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
+ self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))
+ self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
+ self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
+ self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))
+ self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))
+ self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))
+
+ self.mvit = nn.ModuleList([])
+ self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0] * 2)))
+ self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1] * 4)))
+ self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2] * 4)))
+
+ self.conv2 = conv_1x1_bn(channels[-2], channels[-1])
+
+ self.pool = nn.AvgPool2d(ih // 32, 1)
+ self.fc = nn.Linear(channels[-1], num_classes, bias=False)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.mv2[0](x)
+
+ x = self.mv2[1](x)
+ x = self.mv2[2](x)
+ x = self.mv2[3](x)
+
+ x = self.mv2[4](x)
+ x = self.mvit[0](x)
+
+ x = self.mv2[5](x)
+ x = self.mvit[1](x)
+
+ x = self.mv2[6](x)
+ x = self.mvit[2](x)
+ x = self.conv2(x)
+
+ x = self.pool(x).view(-1, x.shape[1])
+ x = self.fc(x)
+ return x
+