From dbb7bd179bbac25d7fb8ff63e6017f0181e3daf3 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 21 Dec 2021 10:20:02 -0800
Subject: [PATCH] release MobileViT, from @murufeng

---
 README.md                 |  15 ++++-
 setup.py                  |   2 +-
 vit_pytorch/mobile_vit.py | 114 +++++++++++++++++++++-----------------
 3 files changed, 76 insertions(+), 55 deletions(-)

diff --git a/README.md b/README.md
index 9e9b57ea..a2ad32f4 100644
--- a/README.md
+++ b/README.md
@@ -559,12 +559,12 @@ perspective for the global processing of information with transformers.
 You can use it with the following code (ex. mobilevit_xs)
 
-```
+```python
 import torch
 from vit_pytorch.mobile_vit import MobileViT
 
 mbvit_xs = MobileViT(
-    image_size=(256, 256),
+    image_size = (256, 256),
     dims = [96, 120, 144],
     channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
     num_classes = 1000
 )
@@ -1190,6 +1190,17 @@ Coming from computer vision and new to transformers? Here are some resources that greatly accelerated my learning.
 }
 ```
 
+```bibtex
+@misc{mehta2021mobilevit,
+    title   = {MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer},
+    author  = {Sachin Mehta and Mohammad Rastegari},
+    year    = {2021},
+    eprint  = {2110.02178},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CV}
+}
+```
+
 ```bibtex
 @misc{vaswani2017attention,
     title   = {Attention Is All You Need},
diff --git a/setup.py b/setup.py
index 889a3124..64db8479 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.24.3',
+  version = '0.25.0',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
diff --git a/vit_pytorch/mobile_vit.py b/vit_pytorch/mobile_vit.py
index bb44d02b..b8b72537 100644
--- a/vit_pytorch/mobile_vit.py
+++ b/vit_pytorch/mobile_vit.py
@@ -9,9 +9,9 @@
 import torch.nn as nn
 
 from einops import rearrange
+from einops.layers.torch import Reduce
 
 def _make_divisible(v, divisor, min_value=None):
-
     if min_value is None:
         min_value = divisor
     new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
@@ -20,7 +20,7 @@ def _make_divisible(v, divisor, min_value=None):
     return new_v
 
 
-def Conv_BN_ReLU(inp, oup, kernel, stride=1):
+def conv_bn_relu(inp, oup, kernel, stride=1):
     return nn.Sequential(
         nn.Conv2d(inp, oup, kernel_size=kernel, stride=stride, padding=1, bias=False),
         nn.BatchNorm2d(oup),
@@ -63,8 +63,6 @@ class Attention(nn.Module):
     def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
         super().__init__()
         inner_dim = dim_head * heads
-        project_out = not (heads == 1 and dim_head == dim)
-
         self.heads = heads
         self.scale = dim_head ** -0.5
 
@@ -74,7 +72,7 @@ def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
         self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim),
             nn.Dropout(dropout)
-        ) if project_out else nn.Identity()
+        )
 
     def forward(self, x):
         qkv = self.to_qkv(x).chunk(3, dim=-1)
@@ -96,6 +94,7 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
                 PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                 PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
             ]))
+
     def forward(self, x):
         for attn, ff in self.layers:
             x = attn(x) + x
@@ -136,23 +135,24 @@ def __init__(self, inp, oup, stride=1, expand_ratio=4):
             )
 
     def forward(self, x):
+        out = self.conv(x)
+
         if self.identity:
-            return x + self.conv(x)
-        else:
-            return self.conv(x)
+            out = out + x
+        return out
 
 class MobileViTBlock(nn.Module):
     def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
         super().__init__()
         self.ph, self.pw = patch_size
 
-        self.conv1 = Conv_BN_ReLU(channel, channel, kernel_size)
+        self.conv1 = conv_bn_relu(channel, channel, kernel_size)
         self.conv2 = conv_1x1_bn(channel, dim)
 
         self.transformer = Transformer(dim, depth, 1, 32, mlp_dim, dropout)
 
         self.conv3 = conv_1x1_bn(dim, channel)
-        self.conv4 = Conv_BN_ReLU(2 * channel, channel, kernel_size)
+        self.conv4 = conv_bn_relu(2 * channel, channel, kernel_size)
 
     def forward(self, x):
         y = x.clone()
@@ -165,8 +165,7 @@ def forward(self, x):
         _, _, h, w = x.shape
         x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
         x = self.transformer(x)
-        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph,
-                      pw=self.pw)
+        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph, pw=self.pw)
 
         # Fusion
         x = self.conv3(x)
@@ -176,54 +175,65 @@
 class MobileViT(nn.Module):
-    def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)):
+    def __init__(
+        self,
+        image_size,
+        dims,
+        channels,
+        num_classes,
+        expansion = 4,
+        kernel_size = 3,
+        patch_size = (2, 2),
+        depths = (2, 4, 3)
+    ):
         super().__init__()
+        assert len(dims) == 3, 'dims must be a tuple of 3'
+        assert len(depths) == 3, 'depths must be a tuple of 3'
+
         ih, iw = image_size
         ph, pw = patch_size
         assert ih % ph == 0 and iw % pw == 0
 
-        L = [2, 4, 3]
-
-        self.conv1 = Conv_BN_ReLU(3, channels[0], kernel=3, stride=2)
-
-        self.mv2 = nn.ModuleList([])
-        self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
-        self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))
-        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
-        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
-        self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))
-        self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))
-        self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))
-
-        self.mvit = nn.ModuleList([])
-        self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0] * 2)))
-        self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1] * 4)))
-        self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2] * 4)))
-
-        self.conv2 = conv_1x1_bn(channels[-2], channels[-1])
-
-        self.pool = nn.AvgPool2d(ih // 32, 1)
-        self.fc = nn.Linear(channels[-1], num_classes, bias=False)
+        init_dim, *_, last_dim = channels
+
+        self.conv1 = conv_bn_relu(3, init_dim, kernel=3, stride=2)
+
+        self.stem = nn.ModuleList([])
+        self.stem.append(MV2Block(channels[0], channels[1], 1, expansion))
+        self.stem.append(MV2Block(channels[1], channels[2], 2, expansion))
+        self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
+        self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
+
+        self.trunk = nn.ModuleList([])
+        self.trunk.append(nn.ModuleList([
+            MV2Block(channels[3], channels[4], 2, expansion),
+            MobileViTBlock(dims[0], depths[0], channels[5], kernel_size, patch_size, int(dims[0] * 2))
+        ]))
+
+        self.trunk.append(nn.ModuleList([
+            MV2Block(channels[5], channels[6], 2, expansion),
+            MobileViTBlock(dims[1], depths[1], channels[7], kernel_size, patch_size, int(dims[1] * 4))
+        ]))
+
+        self.trunk.append(nn.ModuleList([
+            MV2Block(channels[7], channels[8], 2, expansion),
+            MobileViTBlock(dims[2], depths[2], channels[9], kernel_size, patch_size, int(dims[2] * 4))
+        ]))
+
+        self.to_logits = nn.Sequential(
+            conv_1x1_bn(channels[-2], last_dim),
+            Reduce('b c h w -> b c', 'mean'),
+            nn.Linear(channels[-1], num_classes, bias=False)
+        )
 
     def forward(self, x):
         x = self.conv1(x)
 
-        x = self.mv2[0](x)
-
-        x = self.mv2[1](x)
-        x = self.mv2[2](x)
-        x = self.mv2[3](x)
-        x = self.mv2[4](x)
-        x = self.mvit[0](x)
+        for conv in self.stem:
+            x = conv(x)
 
-        x = self.mv2[5](x)
-        x = self.mvit[1](x)
-
-        x = self.mv2[6](x)
-        x = self.mvit[2](x)
-        x = self.conv2(x)
-
-        x = self.pool(x).view(-1, x.shape[1])
-        x = self.fc(x)
-        return x
+        for conv, attn in self.trunk:
+            x = conv(x)
+            x = attn(x)
+
+        return self.to_logits(x)
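
A quick smoke test for the refactored model, not part of the commit itself: a minimal sketch reusing the mobilevit_xs hyperparameters from the README hunk above, checking that the new stem/trunk/to_logits path still produces class logits of the expected shape.

```python
# Sketch only (not in the patch): verify the refactored MobileViT forward pass.
import torch
from vit_pytorch.mobile_vit import MobileViT

mbvit_xs = MobileViT(
    image_size = (256, 256),
    dims = [96, 120, 144],
    channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
    num_classes = 1000
)

img = torch.randn(1, 3, 256, 256)   # (batch, channels, height, width)

logits = mbvit_xs(img)              # conv stem -> MV2/MobileViT trunk -> to_logits head
assert logits.shape == (1, 1000)    # mean-pooled features projected to class logits
```

The `nn.AvgPool2d(ih // 32, 1)` + `view` + `nn.Linear` tail is replaced by a single `nn.Sequential` head, where `Reduce('b c h w -> b c', 'mean')` from `einops.layers.torch` performs the global average pooling, so the output shape is unchanged.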