Skip to content

Commit

Permalink
one can pass a callback to token_dropout_prob for NaViT that takes in…
Browse files Browse the repository at this point in the history
… height and width and calculate appropriate dropout rate
  • Loading branch information
lucidrains committed Jul 24, 2023
1 parent 17675e0 commit cd21090
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.2.7',
version = '1.2.8',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description_content_type = 'text/markdown',
Expand Down
23 changes: 18 additions & 5 deletions vit_pytorch/na_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,14 +138,25 @@ def forward(
return self.norm(x)

class NaViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., token_dropout_prob = 0.):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., token_dropout_prob = None):
super().__init__()
image_height, image_width = pair(image_size)

# what percent of tokens to dropout
# in paper, they found this should vary depending on resolution (todo - figure out how to do this, maybe with callback?)
# if int or float given, then assume constant dropout prob
# otherwise accept a callback that in turn calculates dropout prob from height and width

self.token_dropout_prob = token_dropout_prob
self.calc_token_dropout = calc_token_dropout = None

if callable(token_dropout_prob):
self.calc_token_dropout = token_dropout_prob

elif isinstance(token_dropout_prob, (float, int)):
assert 0. < token_dropout_prob < 1.
token_dropout_prob = float(token_dropout_prob)
self.calc_token_dropout = lambda height, width: token_dropout_prob

# calculate patching related stuff

assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'

Expand Down Expand Up @@ -190,7 +201,7 @@ def forward(
self,
batched_images: List[List[Tensor]] # assume different resolution images already grouped correctly
):
p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, self.token_dropout_prob > 0.
p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, exists(self.calc_token_dropout)

arange = partial(torch.arange, device = device)
pad_sequence = partial(orig_pad_sequence, batch_first = True)
Expand Down Expand Up @@ -227,8 +238,10 @@ def forward(
seq_len = seq.shape[-2]

if has_token_dropout:
num_keep = max(1, int(seq_len * (1 - self.token_dropout_prob)))
token_dropout = self.calc_token_dropout(*image_dims)
num_keep = max(1, int(seq_len * (1 - token_dropout)))
keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices

seq = seq[keep_indices]
pos = pos[keep_indices]

Expand Down

0 comments on commit cd21090

Please sign in to comment.