Skip to content

Commit

Permalink
fix mpp
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jun 28, 2023
1 parent ce4bcd0 commit 9e3fec2
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.2.2',
version = '1.2.4',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description_content_type = 'text/markdown',
Expand Down
5 changes: 4 additions & 1 deletion vit_pytorch/mpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ def __init__(
self.loss = MPPLoss(patch_size, channels, output_channel_bits,
max_pixel_val, mean, std)

# extract patching function
self.patch_to_emb = nn.Sequential(transformer.to_patch_embedding[1:])

# output transformation
self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels))

Expand Down Expand Up @@ -151,7 +154,7 @@ def forward(self, input, **kwargs):
masked_input[bool_mask_replace] = self.mask_token

# linear embedding of patches
masked_input = transformer.to_patch_embedding[-1](masked_input)
masked_input = self.patch_to_emb(masked_input)

# add cls token to input sequence
b, n, _ = masked_input.shape
Expand Down

0 comments on commit 9e3fec2

Please sign in to comment.