diff --git a/scripts/clipav1_vit_l16_i37_t8.sh b/scripts/clipav1_vit_l16_i37_t8.sh new file mode 100644 index 000000000..d3ff0901e --- /dev/null +++ b/scripts/clipav1_vit_l16_i37_t8.sh @@ -0,0 +1,6 @@ +# eval on a single gpu +CUDA_VISIBLE_DEVICES=2 TORCH_CUDNN_V8_API_ENABLED=1 TFDS_PREFETCH_SIZE=8192 python3 -m training.main \ + --model ViT-L-16-CL32-GAP \ + --pretrained "/path/to/clipa_vit_l16_i37_t8.pt" \ + --seed 0 \ + --imagenet-val '/path/to/ImageNet/val' \ No newline at end of file diff --git a/scripts/clipav2_vit_h14_i84_224_336_cl32_gap_datacomp1b.sh b/scripts/clipav2_vit_h14_i84_224_336_cl32_gap_datacomp1b.sh new file mode 100644 index 000000000..7f22386c3 --- /dev/null +++ b/scripts/clipav2_vit_h14_i84_224_336_cl32_gap_datacomp1b.sh @@ -0,0 +1,10 @@ +CUDA_VISIBLE_DEVICES=1 python3 -m training.main \ + --model ViT-H-14-CL32-GAP-BigVision \ + --pretrained "/path/to/vit_h14_i84_224_336_cl32_gap_datacomp1b.pt" \ + --force-image-size 336 \ + --square-resize-only \ + --interpolation 'bilinear' \ + --image-mean 0.485 0.456 0.406 \ + --image-std 0.229 0.224 0.225 \ + --seed 0 \ + --imagenet-val '/path/to/ImageNet/val' diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index 12f3dec30..9b65c2f17 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -19,7 +19,7 @@ from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\ list_pretrained_tags_by_model, download_pretrained_from_hf from .transform import image_transform, AugmentationCfg -from .tokenizer import HFTokenizer, tokenize, syntax_mask_tokenize, random_mask_tokenize, block_mask_tokenize +from .tokenizer import HFTokenizer, tokenize, syntax_mask_tokenize, random_mask_tokenize, block_mask_tokenize, get_pp_bert_tokenize HF_HUB_PREFIX = 'hf-hub:' @@ -88,11 +88,18 @@ def get_tokenizer(model_name): tokenizer = random_mask_tokenize elif 'text_mask' in config['text_cfg'] and config['text_cfg']['text_mask'] == 'block': tokenizer = block_mask_tokenize + elif 'bert_tokenizer' in config['text_cfg'] and config['text_cfg']['bert_tokenizer']: + tokenizer = get_pp_bert_tokenize( + vocab_path=config['text_cfg']['vocab_path'], + max_len=config['text_cfg']['context_length'], + ) else: tokenizer = tokenize + if 'context_length' in config['text_cfg'].keys(): context_length = config['text_cfg']['context_length'] tokenizer = partial(tokenizer, context_length=context_length) + return tokenizer @@ -141,6 +148,7 @@ def create_model( cache_dir: Optional[str] = None, output_dict: Optional[bool] = None, require_pretrained: bool = False, + pos_embed: str = None, **model_kwargs, ): has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX) @@ -190,6 +198,10 @@ def create_model( # override model config's image size model_cfg["vision_cfg"]["image_size"] = force_image_size + if pos_embed is not None: + # override model config's positional embedding + model_cfg["vision_cfg"]["pos_embed"] = pos_embed + is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {}) if pretrained_image: if is_timm_model: @@ -332,6 +344,9 @@ def create_model_and_transforms( aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, cache_dir: Optional[str] = None, output_dict: Optional[bool] = None, + pos_embed: str = None, + interpolation: str = 'bicubic', # only effective for inference + square_resize_only: bool = False, # only effective for inference **model_kwargs, ): model = create_model( @@ -348,6 +363,7 @@ def create_model_and_transforms( pretrained_hf=pretrained_hf, cache_dir=cache_dir, output_dict=output_dict, 
+ pos_embed=pos_embed, **model_kwargs, ) @@ -365,6 +381,8 @@ def create_model_and_transforms( is_train=False, mean=image_mean, std=image_std, + interpolation=interpolation, + square_resize_only=square_resize_only, ) return model, preprocess_train, preprocess_val diff --git a/src/open_clip/model.py b/src/open_clip/model.py index 0ccf01bca..df9a38ad7 100644 --- a/src/open_clip/model.py +++ b/src/open_clip/model.py @@ -12,6 +12,7 @@ import torch.nn.functional as F from torch import nn from torch.utils.checkpoint import checkpoint +from functools import partial from .hf_model import HFTextEncoder from .modified_resnet import ModifiedResNet @@ -45,6 +46,10 @@ class CLIPVisionCfg: timm_proj_bias: bool = False # enable bias final projection timm_drop: float = 0. # head dropout timm_drop_path: Optional[float] = None # backbone stochastic depth + pos_embed: str = 'learnable' + act_kwargs: dict = None + ln_pre: bool = True + pool_style: str = 'open_clip' @dataclass @@ -63,6 +68,11 @@ class CLIPTextCfg: embed_cls: bool = False pad_id: int = 0 output_tokens: bool = False + act_kwargs: dict = None + pool_style: str = 'open_clip' + bert_tokenizer: bool = False + vocab_path: str = None + attention_mask: bool = True text_mask: str = 'first' # default first truncate in bpe_tokenizer @@ -123,6 +133,8 @@ def _build_vision_tower( else: vision_heads = vision_cfg.width // vision_cfg.head_width norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + if vision_cfg.act_kwargs is not None: + act_layer = partial(act_layer, **vision_cfg.act_kwargs) visual = VisionTransformer( image_size=vision_cfg.image_size, patch_size=vision_cfg.patch_size, @@ -141,6 +153,9 @@ def _build_vision_tower( output_dim=embed_dim, act_layer=act_layer, norm_layer=norm_layer, + pos_embed=vision_cfg.pos_embed, + ln_pre=vision_cfg.ln_pre, + pool_style=vision_cfg.pool_style, ) return visual @@ -168,6 +183,9 @@ def _build_text_tower( act_layer = QuickGELU if quick_gelu else nn.GELU norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + if text_cfg.act_kwargs is not None: + act_layer = partial(act_layer, **text_cfg.act_kwargs) + text = TextTransformer( context_length=text_cfg.context_length, vocab_size=text_cfg.vocab_size, @@ -181,6 +199,8 @@ def _build_text_tower( pad_id=text_cfg.pad_id, act_layer=act_layer, norm_layer=norm_layer, + pool_style=text_cfg.pool_style, + attention_mask=text_cfg.attention_mask, ) return text @@ -211,6 +231,7 @@ def __init__( self.positional_embedding = text.positional_embedding self.ln_final = text.ln_final self.text_projection = text.text_projection + self.pool_style = text.pool_style self.register_buffer('attn_mask', text.attn_mask, persistent=False) self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale) @@ -243,7 +264,16 @@ def encode_text(self, text, normalize: bool = False): x = x.permute(1, 0, 2) # LND -> NLD x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] # take features from the eot embedding (eot_token is the highest number in each sequence) - x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + if self.pool_style == 'open_clip': + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + elif self.pool_style == 'big_vision_tok': + pooled = x[:, 0] + x = pooled @ self.text_projection + elif self.pool_style == 'big_vision_last': + pooled = x[:, -1] + x = pooled @ self.text_projection + else: + raise ValueError return F.normalize(x, dim=-1) if normalize else x def 
forward( diff --git a/src/open_clip/model_configs/ViT-H-14-CL32-GAP-BigVision.json b/src/open_clip/model_configs/ViT-H-14-CL32-GAP-BigVision.json new file mode 100644 index 000000000..1b3004d62 --- /dev/null +++ b/src/open_clip/model_configs/ViT-H-14-CL32-GAP-BigVision.json @@ -0,0 +1,30 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14, + "act_kwargs":{ + "approximate": "tanh" + }, + "ln_pre": false, + "pool_style": "big_vision_gap", + "global_average_pool": true + }, + "text_cfg": { + "context_length": 32, + "vocab_size": 32000, + "bert_tokenizer": true, + "vocab_path": "gs://vit_models/lit/bert/uncased_L-12_H-768_A-12/vocab.txt", + "width": 1024, + "heads": 16, + "layers": 24, + "act_kwargs":{ + "approximate": "tanh" + }, + "pool_style": "big_vision_last", + "attention_mask": false + } +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-L-16-CL32-GAP.json b/src/open_clip/model_configs/ViT-L-16-CL32-GAP.json new file mode 100644 index 000000000..270a4607b --- /dev/null +++ b/src/open_clip/model_configs/ViT-L-16-CL32-GAP.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 16, + "global_average_pool": true + }, + "text_cfg": { + "context_length": 32, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/src/open_clip/pos_embed.py b/src/open_clip/pos_embed.py new file mode 100644 index 000000000..5c8082b34 --- /dev/null +++ b/src/open_clip/pos_embed.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# -------------------------------------------------------- +# Position embedding utils +# -------------------------------------------------------- + +import numpy as np + +import torch + +# -------------------------------------------------------- +# 2D sine-cosine position embedding +# References: +# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py +# MoCo v3: https://github.com/facebookresearch/moco-v3 +# -------------------------------------------------------- +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2. + omega = 1. 
/ 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +# -------------------------------------------------------- +# Interpolate position embeddings for high-resolution +# References: +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- +def interpolate_pos_embed(model, checkpoint_model): + if 'pos_embed' in checkpoint_model: + pos_embed_checkpoint = checkpoint_model['pos_embed'] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model['pos_embed'] = new_pos_embed diff --git a/src/open_clip/tokenizer.py b/src/open_clip/tokenizer.py index 3e651aed5..8f3acba13 100644 --- a/src/open_clip/tokenizer.py +++ b/src/open_clip/tokenizer.py @@ -13,6 +13,10 @@ import torch import numpy as np +import tensorflow as tf +import tensorflow_text +tf.config.set_visible_devices([], 'GPU') # Hands off my GPU! (or pip install tensorflow-cpu) + # https://stackoverflow.com/q/62691279 import os os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -161,10 +165,12 @@ def decode(self, tokens): _tokenizer = SimpleTokenizer() + def decode(output_ids: torch.Tensor): output_ids = output_ids.cpu().numpy() return _tokenizer.decode(output_ids) + def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: """ Returns the tokenized representation of given input string(s) @@ -350,4 +356,55 @@ def get_order(x): tokens[-1] = eot_token result[i, :len(tokens)] = torch.tensor(tokens) - return result \ No newline at end of file + return result + + +def _create_bert_tokenizer(vocab_path): + with tf.io.gfile.GFile(vocab_path) as f: + vocab = f.read().split("\n") + cls_token = vocab.index("[CLS]") + return cls_token, tensorflow_text.BertTokenizer( + vocab_path, + token_out_type=tf.int32, + lower_case=True, + ) + + +def get_pp_bert_tokenize(vocab_path, max_len): + """Extracts tokens with tensorflow_text.BertTokenizer. + copied from big_vision. modified to deal with multiple text + Args: + vocab_path: Path to a file containing the vocabulry for the WordPiece + tokenizer. It's the "vocab.txt" file in the zip file downloaded from + the original repo https://github.com/google-research/bert + max_len: Number of tokens after tokenization. 
+ sample_if_multi: Whether the first text should be taken (if set to `False`), + or whether a random text should be tokenized. + + Returns: + A preprocessing Op. + """ + cls_token, tokenizer = _create_bert_tokenizer(vocab_path) + + def _pp_bert_tokenize(labels): + if isinstance(labels, str): + labels = [labels] + + labels = tf.reshape(labels, (-1,)) + output_list = [] + for i in range(tf.shape(labels)[0]): + txt = labels[i] + + token_ids = tokenizer.tokenize(txt[None]) + padded_token_ids, mask = tensorflow_text.pad_model_inputs( + token_ids, max_len - 1) + del mask # Recovered from zero padding in model. + count = tf.shape(padded_token_ids)[0] + padded_token_ids = tf.concat( + [tf.fill([count, 1], cls_token), padded_token_ids], axis=1) + output = padded_token_ids[0].numpy() + output_list.append(output[None, :]) + output = np.concatenate(output_list, axis=0) + return torch.tensor(output) + + return _pp_bert_tokenize diff --git a/src/open_clip/transform.py b/src/open_clip/transform.py index 59f13bb59..b61d2f0e1 100644 --- a/src/open_clip/transform.py +++ b/src/open_clip/transform.py @@ -97,6 +97,8 @@ def image_transform( resize_longest_max: bool = False, fill_color: int = 0, aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, + interpolation: str = 'bicubic', # only effective for inference + square_resize_only: bool = False, ): mean = mean or OPENAI_DATASET_MEAN if not isinstance(mean, (list, tuple)): @@ -132,6 +134,8 @@ def image_transform( # drop extra item aug_cfg_dict.pop('color_jitter_prob', False) aug_cfg_dict.pop('gray_scale_prob', False) + aug_cfg_dict.pop('interpolation', False) + aug_cfg_dict.pop('square_resize_only', False) train_transform = create_transform( input_size=input_size, @@ -169,13 +173,21 @@ def image_transform( warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).') return train_transform else: + assert interpolation in ['bicubic', 'bilinear'] + assert not (resize_longest_max and square_resize_only) if resize_longest_max: transforms = [ ResizeMaxSize(image_size, fill=fill_color) ] + elif square_resize_only: + if isinstance(image_size, int): + image_size = (image_size, image_size) + transforms = [ + Resize(image_size, interpolation=InterpolationMode.BICUBIC if interpolation == 'bicubic' else InterpolationMode.BILINEAR), + ] else: transforms = [ - Resize(image_size, interpolation=InterpolationMode.BICUBIC), + Resize(image_size, interpolation=InterpolationMode.BICUBIC if interpolation == 'bicubic' else InterpolationMode.BILINEAR), CenterCrop(image_size), ] transforms.extend([ diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 0a30e9466..b093eddcd 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -1,6 +1,7 @@ from collections import OrderedDict import math from typing import Callable, Optional, Sequence, Tuple +from functools import partial import torch from torch import nn @@ -8,6 +9,7 @@ from torch.utils.checkpoint import checkpoint from .utils import to_2tuple +from .pos_embed import get_2d_sincos_pos_embed class LayerNormFp32(nn.LayerNorm): @@ -343,7 +345,10 @@ def __init__( input_patchnorm: bool = False, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, - output_tokens: bool = False + output_tokens: bool = False, + pos_embed: str = 'learnable', + ln_pre: bool = True, + pool_style: str = 'open_clip', # only effective when attention_pool is None ): super().__init__() self.output_tokens = output_tokens @@ -366,12 +371,21 @@ def 
__init__( # class embeddings and positional embeddings scale = width ** -0.5 self.class_embedding = nn.Parameter(scale * torch.randn(width)) - self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) + if pos_embed == 'learnable': + self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) + elif pos_embed == 'sin_cos_2d': + # fixed sin-cos embedding + assert self.grid_size[0] == self.grid_size[1], 'currently sin cos 2d pos embedding only supports square input' + self.positional_embedding = nn.Parameter(torch.zeros(self.grid_size[0] * self.grid_size[1] + 1, width), requires_grad=False) + pos_embed = get_2d_sincos_pos_embed(width, self.grid_size[0], cls_token=True) + self.positional_embedding.data.copy_(torch.from_numpy(pos_embed).float()) + else: + raise ValueError # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity() - self.ln_pre = norm_layer(width) + self.ln_pre = norm_layer(width) if ln_pre else nn.Identity() self.transformer = Transformer( width, layers, @@ -392,6 +406,8 @@ def __init__( self.ln_post = norm_layer(width) self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + self.pool_style = pool_style + self.init_parameters() def lock(self, unlocked_groups=0, freeze_bn_stats=False): @@ -451,8 +467,10 @@ def init_parameters(self): def set_grad_checkpointing(self, enable=True): self.transformer.grad_checkpointing = enable - def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - if self.global_average_pool: + def _global_pool(self, x: torch.Tensor, include_cls=True) -> Tuple[torch.Tensor, torch.Tensor]: + if self.global_average_pool and not include_cls: + return x[:, 1:].mean(dim=1), x[:, 1:] + elif self.global_average_pool and include_cls: return x.mean(dim=1), x else: return x[:, 0], x[:, 1:] @@ -491,8 +509,19 @@ def forward(self, x: torch.Tensor): x = self.ln_post(x) pooled, tokens = self._global_pool(x) else: - pooled, tokens = self._global_pool(x) - pooled = self.ln_post(pooled) + if self.pool_style == 'open_clip': + pooled, tokens = self._global_pool(x) + pooled = self.ln_post(pooled) + elif self.pool_style == 'big_vision_tok': + assert not self.global_average_pool + x = self.ln_post(x) + pooled, tokens = self._global_pool(x) + elif self.pool_style == 'big_vision_gap': + assert self.global_average_pool + pooled, tokens = self._global_pool(x, include_cls=False) + pooled = self.ln_post(pooled) + else: + raise ValueError if self.proj is not None: pooled = pooled @ self.proj @@ -520,6 +549,8 @@ def __init__( embed_cls: bool = False, pad_id: int = 0, output_tokens: bool = False, + pool_style: str = 'open_clip', + attention_mask: bool = True, ): super().__init__() self.output_tokens = output_tokens @@ -537,6 +568,9 @@ def __init__( self.num_pos += 1 else: self.cls_emb = None + self.pool_style = pool_style + if self.pool_style == 'big_vision': + assert not embed_cls, 'bert tokenizer in big_vision already append a cls token, so do not use cls_embed in text transformer!' 
self.token_embedding = nn.Embedding(vocab_size, width) self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width)) @@ -550,7 +584,10 @@ def __init__( ) self.ln_final = norm_layer(width) - self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + if attention_mask: + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + else: + self.attn_mask = None self.init_parameters() @@ -619,8 +656,19 @@ def forward(self, text): pooled, tokens = x[:, -1], x[:, :-1] pooled = self.ln_final(pooled) else: - x = self.ln_final(x) - pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x + if self.pool_style == 'open_clip': + x = self.ln_final(x) + pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x + elif self.pool_style == 'big_vision_tok': + x = self.ln_final(x) + # not sure what is the tokens here + pooled, tokens = x[:, 0], x + elif self.pool_style == 'big_vision_last': + x = self.ln_final(x) + # not sure what is the tokens here + pooled, tokens = x[:, -1], x + else: + raise ValueError if self.text_projection is not None: pooled = pooled @ self.text_projection diff --git a/src/training/main.py b/src/training/main.py index 08d2412e2..0c71dab0b 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -234,6 +234,9 @@ def main(args): image_std=args.image_std, aug_cfg=args.aug_cfg, output_dict=True, + pos_embed=args.pos_embed, + interpolation=args.interpolation, # only effective for inference + square_resize_only=args.square_resize_only, # only effective for inference **model_kwargs, ) if args.distill: diff --git a/src/training/params.py b/src/training/params.py index 345382e57..cc1b888a3 100644 --- a/src/training/params.py +++ b/src/training/params.py @@ -442,6 +442,22 @@ def parse_args(args): action="store_true", help='Use SigLip (sigmoid) loss.' ) + parser.add_argument( + '--pos-embed', + default='learnable', type=str, + help="type of positional embedding in vision transformer. support learnable and sin_cos_2d" + ) + parser.add_argument( + '--interpolation', + default='bicubic', type=str, choices=['bicubic', 'bilinear'], + help="resize interpolation during inference" + ) + parser.add_argument( + '--square-resize-only', + default=False, action='store_true', + help="square resize during inference" + ) + args = parser.parse_args(args) # If some params are not passed, we use the default values based on model name.
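
Usage sketch (illustrative only; checkpoint paths are placeholders): evaluating the CLIPAv1 ViT-L/16 weights from Python rather than through scripts/clipav1_vit_l16_i37_t8.sh. With the ViT-L-16-CL32-GAP config added above, get_tokenizer() falls through to the standard BPE tokenizer and picks up context_length=32 from text_cfg.

import torch
from open_clip import create_model_and_transforms, get_tokenizer

model, _, preprocess_val = create_model_and_transforms(
    'ViT-L-16-CL32-GAP',
    pretrained='/path/to/clipa_vit_l16_i37_t8.pt',  # placeholder path
)
tokenizer = get_tokenizer('ViT-L-16-CL32-GAP')      # BPE tokenizer, context_length=32

text = tokenizer(['a photo of a cat', 'a photo of a dog'])    # LongTensor [2, 32]
model.eval()
with torch.no_grad():
    text_features = model.encode_text(text, normalize=True)  # [2, 768]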
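
A second sketch (again with placeholder paths) exercising the new inference-only options threaded through create_model_and_transforms in this diff — square_resize_only, interpolation, and explicit image_mean/image_std — mirroring scripts/clipav2_vit_h14_i84_224_336_cl32_gap_datacomp1b.sh:

import torch
from open_clip import create_model_and_transforms

model, _, preprocess_val = create_model_and_transforms(
    'ViT-H-14-CL32-GAP-BigVision',
    pretrained='/path/to/vit_h14_i84_224_336_cl32_gap_datacomp1b.pt',  # placeholder path
    force_image_size=336,
    square_resize_only=True,   # resize straight to (336, 336); no center crop
    interpolation='bilinear',  # only affects the eval transform
    image_mean=(0.485, 0.456, 0.406),
    image_std=(0.229, 0.224, 0.225),
)
model.eval()
with torch.no_grad():
    image_features = model.encode_image(torch.randn(1, 3, 336, 336), normalize=True)
print(image_features.shape)  # torch.Size([1, 1024])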
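
The BERT tokenizer path can also be called directly; a hedged sketch, assuming tensorflow and tensorflow_text are installed and using a placeholder vocab path:

from open_clip.tokenizer import get_pp_bert_tokenize

bert_tokenize = get_pp_bert_tokenize(
    vocab_path='/path/to/uncased_L-12_H-768_A-12/vocab.txt',  # placeholder path
    max_len=32,
)
# One row of token ids per caption; each row starts with the [CLS] id and is
# padded/truncated to max_len word pieces.
tokens = bert_tokenize(['a photo of a cat', 'a photo of a dog'])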
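
Finally, a quick shape check for the fixed 2D sin-cos positional embedding in pos_embed.py (selected via --pos-embed sin_cos_2d); a 14x14 grid corresponds to ViT-L/16 at 224px:

import numpy as np
from open_clip.pos_embed import get_2d_sincos_pos_embed

width, grid = 1024, 14   # ViT-L/16 vision width; 224 / 16 = 14 patches per side
pe = get_2d_sincos_pos_embed(width, grid, cls_token=True)
assert pe.shape == (grid * grid + 1, width)   # (197, 1024): cls slot + 14*14 patches
assert np.allclose(pe[0], 0.0)                # cls slot is zeros, not sinusoidal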