diff --git a/setup.py b/setup.py index 0b04304..535f1ff 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '1.2.2', + version = '1.2.4', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', long_description_content_type = 'text/markdown', diff --git a/vit_pytorch/mpp.py b/vit_pytorch/mpp.py index e8b044d..754d50c 100644 --- a/vit_pytorch/mpp.py +++ b/vit_pytorch/mpp.py @@ -96,6 +96,9 @@ def __init__( self.loss = MPPLoss(patch_size, channels, output_channel_bits, max_pixel_val, mean, std) + # extract patching function + self.patch_to_emb = nn.Sequential(transformer.to_patch_embedding[1:]) + # output transformation self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels)) @@ -151,7 +154,7 @@ def forward(self, input, **kwargs): masked_input[bool_mask_replace] = self.mask_token # linear embedding of patches - masked_input = transformer.to_patch_embedding[-1](masked_input) + masked_input = self.patch_to_emb(masked_input) # add cls token to input sequence b, n, _ = masked_input.shape