Skip to content

Commit

Permalink
Merge pull request #181 from murufeng/main
Browse files Browse the repository at this point in the history
Add MobileViT
  • Loading branch information
lucidrains authored Dec 21, 2021
2 parents f4b0b14 + 89d3a04 commit 86a7302
Show file tree
Hide file tree
Showing 3 changed files with 255 additions and 0 deletions.
26 changes: 26 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
- [CrossFormer](#crossformer)
- [RegionViT](#regionvit)
- [NesT](#nest)
- [MobileViT](#mobilevit)
- [Masked Autoencoder](#masked-autoencoder)
- [Simple Masked Image Modeling](#simple-masked-image-modeling)
- [Masked Patch Prediction](#masked-patch-prediction)
Expand Down Expand Up @@ -549,6 +550,31 @@ img = torch.randn(1, 3, 224, 224)
pred = nest(img) # (1, 1000)
```

## MobileViT

<img src="./images/mbvit.png" width="400px"></img>

This <a href="https://arxiv.org/abs/2110.02178">paper</a> introduce MobileViT, a light-weight and generalpurpose vision transformer for mobile devices. MobileViT presents a different
perspective for the global processing of information with transformers.

You can use it with the following code (ex. mobilevit_xs)

```
import torch
from vit_pytorch.mobile_vit import MobileViT
mbvit_xs = MobileViT(
image_size=(256, 256),
dims = [96, 120, 144],
channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
num_classes = 1000
)
img = torch.randn(1, 3, 256, 256)
pred = mbvit_xs(img) # (1, 1000)
```

## Simple Masked Image Modeling

<img src="./images/simmim.png" width="400px"/>
Expand Down
Binary file added images/mbvit.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
229 changes: 229 additions & 0 deletions vit_pytorch/mobile_vit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
"""
An implementation of MobileViT Model as defined in:
MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer
Arxiv: https://arxiv.org/abs/2110.02178
Origin Code: https://github.com/murufeng/awesome_lightweight_networks
"""

import torch
import torch.nn as nn

from einops import rearrange

def _make_divisible(v, divisor, min_value=None):

if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v:
new_v += divisor
return new_v


def Conv_BN_ReLU(inp, oup, kernel, stride=1):
return nn.Sequential(
nn.Conv2d(inp, oup, kernel_size=kernel, stride=stride, padding=1, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU6(inplace=True)
)


def conv_1x1_bn(inp, oup):
return nn.Sequential(
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU6(inplace=True)
)

class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn

def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.):
super().__init__()
self.ffn = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.SiLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)

def forward(self, x):
return self.ffn(x)


class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)

self.heads = heads
self.scale = dim_head ** -0.5

self.attend = nn.Softmax(dim=-1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()

def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)

dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b p h n d -> b p n (h d)')
return self.to_out(out)


class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x

class MV2Block(nn.Module):
def __init__(self, inp, oup, stride=1, expand_ratio=4):
super(MV2Block, self).__init__()
assert stride in [1, 2]

hidden_dim = round(inp * expand_ratio)
self.identity = stride == 1 and inp == oup

if expand_ratio == 1:
self.conv = nn.Sequential(
# dw
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.SiLU(),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
)
else:
self.conv = nn.Sequential(
# pw
nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.SiLU(),
# dw
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.SiLU(),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
)

def forward(self, x):
if self.identity:
return x + self.conv(x)
else:
return self.conv(x)

class MobileViTBlock(nn.Module):
def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
super().__init__()
self.ph, self.pw = patch_size

self.conv1 = Conv_BN_ReLU(channel, channel, kernel_size)
self.conv2 = conv_1x1_bn(channel, dim)

self.transformer = Transformer(dim, depth, 1, 32, mlp_dim, dropout)

self.conv3 = conv_1x1_bn(dim, channel)
self.conv4 = Conv_BN_ReLU(2 * channel, channel, kernel_size)

def forward(self, x):
y = x.clone()

# Local representations
x = self.conv1(x)
x = self.conv2(x)

# Global representations
_, _, h, w = x.shape
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
x = self.transformer(x)
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph,
pw=self.pw)

# Fusion
x = self.conv3(x)
x = torch.cat((x, y), 1)
x = self.conv4(x)
return x


class MobileViT(nn.Module):
def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)):
super().__init__()
ih, iw = image_size
ph, pw = patch_size
assert ih % ph == 0 and iw % pw == 0

L = [2, 4, 3]

self.conv1 = Conv_BN_ReLU(3, channels[0], kernel=3, stride=2)

self.mv2 = nn.ModuleList([])
self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))
self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))
self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))
self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))

self.mvit = nn.ModuleList([])
self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0] * 2)))
self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1] * 4)))
self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2] * 4)))

self.conv2 = conv_1x1_bn(channels[-2], channels[-1])

self.pool = nn.AvgPool2d(ih // 32, 1)
self.fc = nn.Linear(channels[-1], num_classes, bias=False)

def forward(self, x):
x = self.conv1(x)
x = self.mv2[0](x)

x = self.mv2[1](x)
x = self.mv2[2](x)
x = self.mv2[3](x)

x = self.mv2[4](x)
x = self.mvit[0](x)

x = self.mv2[5](x)
x = self.mvit[1](x)

x = self.mv2[6](x)
x = self.mvit[2](x)
x = self.conv2(x)

x = self.pool(x).view(-1, x.shape[1])
x = self.fc(x)
return x

0 comments on commit 86a7302

Please sign in to comment.