diff --git a/src/open_clip/big_vision.py b/src/open_clip/convert.py
similarity index 77%
rename from src/open_clip/big_vision.py
rename to src/open_clip/convert.py
index 0d7eaf3fa..0bfe35112 100644
--- a/src/open_clip/big_vision.py
+++ b/src/open_clip/convert.py
@@ -1,7 +1,11 @@
+""" Conversion functions for 3rd party state-dicts and non-torch native checkpoint formats.
+"""
+from typing import Union
+
 import torch
 import numpy as np
 
-from .model import CustomTextCLIP
+from .model import CLIP, CustomTextCLIP
 from .transformer import TextTransformer, Transformer
@@ -134,3 +138,48 @@ def _convert_openclip_txt(module: TextTransformer, prefix):
     model.logit_scale.copy_(_n2p(w['params/t'])[0])
 
 
+@torch.no_grad()
+def convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict):
+    from timm.models.fastvit import _checkpoint_filter_fn
+
+    def _convert_timm_img(state_dict, prefix='image_encoder.'):
+        timm_state_dict = _checkpoint_filter_fn(state_dict, model.visual.trunk)
+        timm_state_dict = {'visual.trunk.' + k: v for k, v in timm_state_dict.items()}
+        return timm_state_dict
+
+    def _convert_openclip_txt(state_dict, prefix='text_encoder.'):
+        text_dict = {}
+        for k, v in state_dict.items():
+            if not k.startswith(prefix):
+                continue
+            k = k.replace(prefix, '')
+            k = k.replace('projection_layer', 'text_projection')
+            k = k.replace('embedding_layer', 'token_embedding')
+            if k.startswith('positional_embedding.pos_embed.pos_embed'):
+                k = k.replace('positional_embedding.pos_embed.pos_embed', 'positional_embedding')
+                v = v.squeeze()
+            k = k.replace('final_layer_norm', 'ln_final')
+            k = k.replace('pre_norm_mha.0', 'ln_1')
+            k = k.replace('pre_norm_mha.1', 'attn')
+            k = k.replace('pre_norm_ffn.0', 'ln_2')
+            k = k.replace('pre_norm_ffn.1', 'mlp.c_fc')
+            k = k.replace('pre_norm_ffn.4', 'mlp.c_proj')
+            k = k.replace('qkv_proj.weight', 'in_proj_weight')
+            k = k.replace('qkv_proj.bias', 'in_proj_bias')
+            k = k.replace('transformer.', 'transformer.resblocks.')
+            text_dict['text.' + k] = v
+        return text_dict
+
+    image_dict = _convert_timm_img(state_dict)
+    text_dict = _convert_openclip_txt(state_dict)
+    out_dict = {**image_dict, **text_dict}
+    out_dict['logit_scale'] = state_dict['logit_scale']
+    return out_dict
+
+
+def convert_state_dict(model: Union[CustomTextCLIP, CLIP], state_dict):
+    if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict:
+        # Apple MobileCLIP s1 & s2 state_dicts (s0 and b not currently supported)
+        state_dict = convert_mobile_clip_state_dict(model, state_dict)
+
+    return state_dict
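
As a quick sanity check of the new conversion path, here is a minimal sketch of loading an Apple-format MobileCLIP checkpoint by hand. It assumes a locally downloaded mobileclip_s1.pt containing a flat state_dict (the filename and layout are assumptions, not taken from this patch) plus the mobileclip_s1 model config added further down.

import torch
import open_clip
from open_clip.convert import convert_state_dict

# Build an uninitialized open_clip model matching the new mobileclip_s1 config.
model = open_clip.create_model('mobileclip_s1')

# Assumed: the released checkpoint is a flat state_dict saved with torch.save().
state_dict = torch.load('mobileclip_s1.pt', map_location='cpu')

# Remap image/text tower keys into the open_clip namespace, then load.
state_dict = convert_state_dict(model, state_dict)
model.load_state_dict(state_dict)
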
diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py
index cf62d5a1c..86b44862f 100644
--- a/src/open_clip/factory.py
+++ b/src/open_clip/factory.py
@@ -10,6 +10,7 @@ import torch
 
 from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
+from .convert import convert_state_dict
 from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
     resize_pos_embed, get_cast_dtype, resize_text_pos_embed, set_model_preprocess_cfg
 from .coca_model import CoCa
@@ -139,25 +140,39 @@ def load_state_dict(checkpoint_path: str, map_location='cpu'):
     return state_dict
 
 
-def load_checkpoint(model, checkpoint_path, strict=True):
+def load_checkpoint(
+        model: Union[CLIP, CustomTextCLIP],
+        checkpoint_path: str,
+        strict: bool = True,
+):
     if Path(checkpoint_path).suffix in ('.npz', '.npy'):
-        from .big_vision import load_big_vision_weights
+        # Separate path loading numpy big_vision (SigLIP) weights
+        from open_clip.convert import load_big_vision_weights
         load_big_vision_weights(model, checkpoint_path)
         return {}
 
     state_dict = load_state_dict(checkpoint_path)
-    # detect old format and make compatible with new format
+
+    # Detect & convert 3rd party state_dicts -> open_clip
+    state_dict = convert_state_dict(model, state_dict)
+
+    # Detect old format and make compatible with new format
     if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
         state_dict = convert_to_custom_text_state_dict(state_dict)
+    # If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712
     if 'logit_bias' not in state_dict and model.logit_bias is not None:
         state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"])
+    # Certain text transformers no longer expect position_ids after transformers==4.31
     position_id_key = 'text.transformer.embeddings.position_ids'
     if position_id_key in state_dict and not hasattr(model, position_id_key):
         del state_dict[position_id_key]
+
     resize_pos_embed(state_dict, model)
     resize_text_pos_embed(state_dict, model)
+
+    # Finally, load the massaged state_dict into model
     incompatible_keys = model.load_state_dict(state_dict, strict=strict)
     return incompatible_keys
diff --git a/src/open_clip/model_configs/mobileclip_s1.json b/src/open_clip/model_configs/mobileclip_s1.json
new file mode 100644
index 000000000..80780c5ea
--- /dev/null
+++ b/src/open_clip/model_configs/mobileclip_s1.json
@@ -0,0 +1,21 @@
+{
+    "embed_dim": 512,
+    "vision_cfg": {
+        "timm_model_name": "fastvit_mci1",
+        "timm_model_pretrained": false,
+        "timm_pool": "avg",
+        "timm_proj": null,
+        "timm_drop": 0.0,
+        "timm_drop_path": 0.0,
+        "image_size": 256
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12,
+        "no_causal_mask": true
+    },
+    "custom_text": true
+}
\ No newline at end of file
diff --git a/src/open_clip/model_configs/mobileclip_s2.json b/src/open_clip/model_configs/mobileclip_s2.json
new file mode 100644
index 000000000..66ebc16aa
--- /dev/null
+++ b/src/open_clip/model_configs/mobileclip_s2.json
@@ -0,0 +1,21 @@
+{
+    "embed_dim": 512,
+    "vision_cfg": {
+        "timm_model_name": "fastvit_mci2",
+        "timm_model_pretrained": false,
+        "timm_pool": "avg",
+        "timm_proj": null,
+        "timm_drop": 0.0,
+        "timm_drop_path": 0.0,
+        "image_size": 256
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12,
+        "no_causal_mask": true
+    },
+    "custom_text": true
+}
\ No newline at end of file
diff --git a/src/open_clip/pretrained.py b/src/open_clip/pretrained.py
index e43e773fd..399f9bdae 100644
--- a/src/open_clip/pretrained.py
+++ b/src/open_clip/pretrained.py
@@ -65,6 +65,20 @@ def _apcfg(url='', hf_hub='', **kwargs):
     }
 
 
+def _mccfg(url='', hf_hub='', **kwargs):
+    # MobileCLIP
+    return {
+        'url': url,
+        'hf_hub': hf_hub,
+        'mean': (0., 0., 0.),
+        'std': (1., 1., 1.),
+        'interpolation': 'bilinear',
+        'resize_mode': 'shortest',
+        **kwargs,
+    }
+
+
+
 _RN50 = dict(
     openai=_pcfg(
         "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
@@ -438,7 +452,12 @@ def _apcfg(url='', hf_hub='', **kwargs):
     "nllb-clip-large-siglip": dict(
         v1=_slpcfg(hf_hub='visheratin/nllb-clip-large-siglip/'),
        mrl=_slpcfg(hf_hub='visheratin/nllb-siglip-mrl-large/'),
-    )
+    ),
+
+    "mobileclip_s1": dict(
+        datacompdr=_mccfg(url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s1.pt')),
+    "mobileclip_s2": dict(
+        datacompdr=_mccfg(url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s2.pt'))
 }
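
With the model configs and pretrained entries above in place, the new models should be reachable through the regular factory API. A minimal usage sketch, assuming the Apple-hosted weight URLs above are downloadable:

import open_clip

# 'datacompdr' is the pretrained tag registered above; on first use the
# checkpoint is fetched from the Apple URL and remapped on load by
# convert_state_dict() before the weights are applied.
model, _, preprocess = open_clip.create_model_and_transforms(
    'mobileclip_s1',
    pretrained='datacompdr',
)
tokenizer = open_clip.get_tokenizer('mobileclip_s1')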