Skip to content

Commit

Permalink
add EsViT, by popular request, an alternative to Dino that is compati…
Browse files Browse the repository at this point in the history
…ble with efficient ViTs with accounting for regional self-supervised loss
  • Loading branch information
lucidrains committed May 3, 2022
1 parent c2aab05 commit 70284c0
Show file tree
Hide file tree
Showing 5 changed files with 448 additions and 4 deletions.
75 changes: 75 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
- [Parallel ViT](#parallel-vit)
- [Learnable Memory ViT](#learnable-memory-vit)
- [Dino](#dino)
- [EsViT](#esvit)
- [Accessing Attention](#accessing-attention)
- [Research Ideas](#research-ideas)
* [Efficient Attention](#efficient-attention)
Expand Down Expand Up @@ -1076,6 +1077,80 @@ for _ in range(100):
torch.save(model.state_dict(), './pretrained-net.pt')
```

## EsViT

<img src="./images/esvit.png" width="350px"></img>

<a href="https://arxiv.org/abs/2106.09785">`EsViT`</a> is a variant of Dino (from above) re-engineered to support efficient `ViT`s with patch merging / downsampling by taking into an account an extra regional loss between the augmented views. To quote the abstract, it `outperforms its supervised counterpart on 17 out of 18 datasets` at 3 times higher throughput.

Even though it is named as though it were a new `ViT` variant, it actually is just a strategy for training any multistage `ViT` (in the paper, they focused on Swin). The example below will show how to use it with `CvT`. You'll need to set the `hidden_layer` to the name of the layer within your efficient ViT that outputs the non-average pooled visual representations, just before the global pooling and projection to logits.

```python
import torch
from vit_pytorch.cvt import CvT
from vit_pytorch.es_vit import EsViTTrainer

cvt = CvT(
num_classes = 1000,
s1_emb_dim = 64,
s1_emb_kernel = 7,
s1_emb_stride = 4,
s1_proj_kernel = 3,
s1_kv_proj_stride = 2,
s1_heads = 1,
s1_depth = 1,
s1_mlp_mult = 4,
s2_emb_dim = 192,
s2_emb_kernel = 3,
s2_emb_stride = 2,
s2_proj_kernel = 3,
s2_kv_proj_stride = 2,
s2_heads = 3,
s2_depth = 2,
s2_mlp_mult = 4,
s3_emb_dim = 384,
s3_emb_kernel = 3,
s3_emb_stride = 2,
s3_proj_kernel = 3,
s3_kv_proj_stride = 2,
s3_heads = 4,
s3_depth = 10,
s3_mlp_mult = 4,
dropout = 0.
)

learner = EsViTTrainer(
cvt,
image_size = 256,
hidden_layer = 'layers', # hidden layer name or index, from which to extract the embedding
projection_hidden_size = 256, # projector network hidden dimension
projection_layers = 4, # number of layers in projection network
num_classes_K = 65336, # output logits dimensions (referenced as K in paper)
student_temp = 0.9, # student temperature
teacher_temp = 0.04, # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs
local_upper_crop_scale = 0.4, # upper bound for local crop - 0.4 was recommended in the paper
global_lower_crop_scale = 0.5, # lower bound for global crop - 0.5 was recommended in the paper
moving_average_decay = 0.9, # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok
center_moving_average_decay = 0.9, # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok
)

opt = torch.optim.AdamW(learner.parameters(), lr = 3e-4)

def sample_unlabelled_images():
return torch.randn(8, 3, 256, 256)

for _ in range(1000):
images = sample_unlabelled_images()
loss = learner(images)
opt.zero_grad()
loss.backward()
opt.step()
learner.update_moving_average() # update moving average of teacher encoder and teacher centers

# save your improved network
torch.save(cvt.state_dict(), './pretrained-net.pt')
```

## Accessing Attention

If you would like to visualize the attention weights (post-softmax) for your research, just follow the procedure below
Expand Down
Binary file added images/esvit.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.33.2',
version = '0.34.1',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
Expand Down
8 changes: 5 additions & 3 deletions vit_pytorch/cvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,14 @@ def __init__(

dim = config['emb_dim']

self.layers = nn.Sequential(
*layers,
self.layers = nn.Sequential(*layers)

self.to_logits = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
Rearrange('... () () -> ...'),
nn.Linear(dim, num_classes)
)

def forward(self, x):
return self.layers(x)
latents = self.layers(x)
return self.to_logits(latents)
Loading

0 comments on commit 70284c0

Please sign in to comment.