From 9e3fec23984b8189d31f628076765bd0d28a1ae5 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Wed, 28 Jun 2023 08:02:43 -0700
Subject: [PATCH] fix mpp

---
 setup.py           | 2 +-
 vit_pytorch/mpp.py | 5 ++++-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/setup.py b/setup.py
index 0b043043..535f1ff5 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.2.2',
+  version = '1.2.4',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description_content_type = 'text/markdown',
diff --git a/vit_pytorch/mpp.py b/vit_pytorch/mpp.py
index e8b044d1..754d50ce 100644
--- a/vit_pytorch/mpp.py
+++ b/vit_pytorch/mpp.py
@@ -96,6 +96,9 @@ def __init__(
 
         self.loss = MPPLoss(patch_size, channels, output_channel_bits,
                             max_pixel_val, mean, std)
 
+        # extract patching function
+        self.patch_to_emb = nn.Sequential(transformer.to_patch_embedding[1:])
+
         # output transformation
         self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels))
@@ -151,7 +154,7 @@ def forward(self, input, **kwargs):
         masked_input[bool_mask_replace] = self.mask_token
 
         # linear embedding of patches
-        masked_input = transformer.to_patch_embedding[-1](masked_input)
+        masked_input = self.patch_to_emb(masked_input)
 
         # add cls token to input sequence
         b, n, _ = masked_input.shape
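
Context for the fix: ViT.to_patch_embedding is an nn.Sequential whose first
module rearranges the image into flattened patches and whose remaining modules
(LayerNorm -> Linear -> LayerNorm in recent vit-pytorch versions) embed them.
MPP.forward rearranges and masks the raw patches itself, so applying only
to_patch_embedding[-1] ran just the final LayerNorm and skipped the nn.Linear
projection; slicing with [1:] keeps the whole embedding path while dropping the
duplicate patching step. Below is a minimal training sketch, assuming the
constructor arguments shown (they follow the repository README; the exact layer
layout of to_patch_embedding may differ across vit-pytorch versions):

    import torch
    from vit_pytorch import ViT
    from vit_pytorch.mpp import MPP

    model = ViT(
        image_size = 256,
        patch_size = 32,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 8,
        mlp_dim = 2048
    )

    mpp_trainer = MPP(
        transformer = model,
        patch_size = 32,
        dim = 1024,
        mask_prob = 0.15,         # probability of a patch taking part in the masked prediction task
        random_patch_prob = 0.30, # probability of a masked patch being swapped for a random patch
        replace_prob = 0.50       # probability of a masked patch being replaced by the mask token
    )

    opt = torch.optim.Adam(mpp_trainer.parameters(), lr = 3e-4)

    images = torch.randn(8, 3, 256, 256)  # stand-in for a batch of unlabelled images

    loss = mpp_trainer(images)
    opt.zero_grad()
    loss.backward()
    opt.step()

With the patch applied, the sketch above exercises the corrected path: the MPP
module patches and masks the images, then self.patch_to_emb embeds the masked
patches through the transformer's own embedding layers.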