diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index 5da4179c0..d1b9d7be4 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -11,80 +11,57 @@ from .coca_model import CoCa from .loss import ClipLoss, CoCaLoss, DistillClipLoss, SigLipLoss -from .model import ( - CLIP, - CustomTextCLIP, - convert_to_custom_text_state_dict, - convert_weights_to_lp, - get_cast_dtype, - resize_pos_embed, - resize_text_pos_embed, - set_model_preprocess_cfg, -) +from .model import (CLIP, CustomTextCLIP, convert_to_custom_text_state_dict, + convert_weights_to_lp, get_cast_dtype, resize_pos_embed, + resize_text_pos_embed, set_model_preprocess_cfg) from .openai import load_openai_model -from .pretrained import ( - download_pretrained, - download_pretrained_from_hf, - get_pretrained_cfg, - list_pretrained_tags_by_model, -) -from .tokenizer import ( - DEFAULT_CONTEXT_LENGTH, - HFTokenizer, - NLLBTokenizer, - SimpleTokenizer, -) -from .transform import ( - AugmentationCfg, - PreprocessCfg, - image_transform_v2, - merge_preprocess_dict, - merge_preprocess_kwargs, -) - -HF_HUB_PREFIX = "hf-hub:" +from .pretrained import (download_pretrained, download_pretrained_from_hf, + get_pretrained_cfg, list_pretrained_tags_by_model) +from .tokenizer import (DEFAULT_CONTEXT_LENGTH, HFTokenizer, NLLBTokenizer, + SimpleTokenizer) +from .transform import (AugmentationCfg, PreprocessCfg, image_transform_v2, + merge_preprocess_dict, merge_preprocess_kwargs) + +HF_HUB_PREFIX = 'hf-hub:' _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs def _natural_key(string_): - return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())] + return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] def _rescan_model_configs(): global _MODEL_CONFIGS - config_ext = (".json",) + config_ext = ('.json',) config_files = [] for config_path in _MODEL_CONFIG_PATHS: if config_path.is_file() and config_path.suffix in config_ext: config_files.append(config_path) elif config_path.is_dir(): for ext in config_ext: - config_files.extend(config_path.glob(f"*{ext}")) + config_files.extend(config_path.glob(f'*{ext}')) for cf in config_files: - with open(cf, "r") as f: + with open(cf, 'r') as f: model_cfg = json.load(f) - if all(a in model_cfg for a in ("embed_dim", "vision_cfg", "text_cfg")): + if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')): _MODEL_CONFIGS[cf.stem] = model_cfg - _MODEL_CONFIGS = { - k: v - for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])) - } + _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))} _rescan_model_configs() # initial populate of model config registry def list_models(): - """enumerate available model architectures based on config files""" + """ enumerate available model architectures based on config files """ return list(_MODEL_CONFIGS.keys()) def add_model_config(path): - """add model config path or file and update registry""" + """ add model config path or file and update registry """ if not isinstance(path, Path): path = Path(path) _MODEL_CONFIG_PATHS.append(path) @@ -99,23 +76,21 @@ def get_model_config(model_name): def _get_hf_config(model_id, cache_dir=None): - config_path = download_pretrained_from_hf( - model_id, filename="open_clip_config.json", cache_dir=cache_dir - ) - with open(config_path, "r", encoding="utf-8") as f: + config_path = 
download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir) + with open(config_path, 'r', encoding='utf-8') as f: config = json.load(f) return config def get_tokenizer( - model_name: str = "", - context_length: Optional[int] = None, - **kwargs, + model_name: str = '', + context_length: Optional[int] = None, + **kwargs, ): if model_name.startswith(HF_HUB_PREFIX): - model_name = model_name[len(HF_HUB_PREFIX) :] + model_name = model_name[len(HF_HUB_PREFIX):] try: - config = _get_hf_config(model_name)["model_cfg"] + config = _get_hf_config(model_name)['model_cfg'] except Exception: tokenizer = HFTokenizer( model_name, @@ -127,25 +102,25 @@ def get_tokenizer( config = get_model_config(model_name) assert config is not None, f"No valid model config found for {model_name}." - text_config = config.get("text_cfg", {}) - if "tokenizer_kwargs" in text_config: - tokenizer_kwargs = dict(text_config["tokenizer_kwargs"], **kwargs) + text_config = config.get('text_cfg', {}) + if 'tokenizer_kwargs' in text_config: + tokenizer_kwargs = dict(text_config['tokenizer_kwargs'], **kwargs) else: tokenizer_kwargs = kwargs if context_length is None: - context_length = text_config.get("context_length", DEFAULT_CONTEXT_LENGTH) + context_length = text_config.get('context_length', DEFAULT_CONTEXT_LENGTH) - if "hf_tokenizer_name" in text_config: + if 'hf_tokenizer_name' in text_config: if model_name.startswith("nllb"): tokenizer = NLLBTokenizer( - text_config["hf_tokenizer_name"], - context_length=context_length, - **tokenizer_kwargs, - ) + text_config['hf_tokenizer_name'], + context_length=context_length, + **tokenizer_kwargs, + ) else: tokenizer = HFTokenizer( - text_config["hf_tokenizer_name"], + text_config['hf_tokenizer_name'], context_length=context_length, **tokenizer_kwargs, ) @@ -158,39 +133,36 @@ def get_tokenizer( return tokenizer -def load_state_dict(checkpoint_path: str, map_location="cpu"): +def load_state_dict(checkpoint_path: str, map_location='cpu'): checkpoint = torch.load(checkpoint_path, map_location=map_location) - if isinstance(checkpoint, dict) and "state_dict" in checkpoint: - state_dict = checkpoint["state_dict"] + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] elif isinstance(checkpoint, torch.jit.ScriptModule): state_dict = checkpoint.state_dict() for key in ["input_resolution", "context_length", "vocab_size"]: state_dict.pop(key, None) else: state_dict = checkpoint - if next(iter(state_dict.items()))[0].startswith("module"): + if next(iter(state_dict.items()))[0].startswith('module'): state_dict = {k[7:]: v for k, v in state_dict.items()} return state_dict def load_checkpoint(model, checkpoint_path, strict=True): - if Path(checkpoint_path).suffix in (".npz", ".npy"): + if Path(checkpoint_path).suffix in ('.npz', '.npy'): from .big_vision import load_big_vision_weights - load_big_vision_weights(model, checkpoint_path) return {} state_dict = load_state_dict(checkpoint_path) # detect old format and make compatible with new format - if "positional_embedding" in state_dict and not hasattr( - model, "positional_embedding" - ): + if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): state_dict = convert_to_custom_text_state_dict(state_dict) # If loading a non-SigLIP model for SigLIP training. 
See https://github.com/mlfoundations/open_clip/issues/712 - if "logit_bias" not in state_dict and model.logit_bias is not None: + if 'logit_bias' not in state_dict and model.logit_bias is not None: state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"]) # Certain text transformers no longer expect position_ids after transformers==4.31 - position_id_key = "text.transformer.embeddings.position_ids" + position_id_key = 'text.transformer.embeddings.position_ids' if position_id_key in state_dict and not hasattr(model, position_id_key): del state_dict[position_id_key] resize_pos_embed(state_dict, model) @@ -200,45 +172,43 @@ def load_checkpoint(model, checkpoint_path, strict=True): def create_model( - model_name: str, - pretrained: Optional[str] = None, - precision: str = "fp32", - device: Union[str, torch.device] = "cpu", - jit: bool = False, - force_quick_gelu: bool = False, - force_custom_text: bool = False, - force_patch_dropout: Optional[float] = None, - force_image_size: Optional[Union[int, Tuple[int, int]]] = None, - force_preprocess_cfg: Optional[Dict[str, Any]] = None, - pretrained_image: bool = False, - pretrained_hf: bool = True, - cache_dir: Optional[str] = None, - output_dict: Optional[bool] = None, - require_pretrained: bool = False, - **model_kwargs, + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_text: bool = False, + force_patch_dropout: Optional[float] = None, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + force_preprocess_cfg: Optional[Dict[str, Any]] = None, + pretrained_image: bool = False, + pretrained_hf: bool = True, + cache_dir: Optional[str] = None, + output_dict: Optional[bool] = None, + require_pretrained: bool = False, + **model_kwargs, ): force_preprocess_cfg = force_preprocess_cfg or {} preprocess_cfg = asdict(PreprocessCfg()) has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX) if has_hf_hub_prefix: - model_id = model_name[len(HF_HUB_PREFIX) :] + model_id = model_name[len(HF_HUB_PREFIX):] checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir) config = _get_hf_config(model_id, cache_dir) - preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config["preprocess_cfg"]) - model_cfg = config["model_cfg"] + preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg']) + model_cfg = config['model_cfg'] pretrained_hf = False # override, no need to load original HF text weights else: - model_name = model_name.replace( - "/", "-" - ) # for callers using old naming with / in ViT names + model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names checkpoint_path = None model_cfg = None if isinstance(device, str): device = torch.device(device) - if pretrained and pretrained.lower() == "openai": - logging.info(f"Loading pretrained {model_name} from OpenAI.") + if pretrained and pretrained.lower() == 'openai': + logging.info(f'Loading pretrained {model_name} from OpenAI.') model = load_openai_model( model_name, precision=precision, @@ -248,12 +218,10 @@ def create_model( else: model_cfg = model_cfg or get_model_config(model_name) if model_cfg is not None: - logging.info(f"Loaded {model_name} model config.") + logging.info(f'Loaded {model_name} model config.') else: - logging.error( - f"Model config for {model_name} not found; available models {list_models()}." 
- ) - raise RuntimeError(f"Model config for {model_name} not found.") + logging.error(f'Model config for {model_name} not found; available models {list_models()}.') + raise RuntimeError(f'Model config for {model_name} not found.') if force_quick_gelu: # override for use of QuickGELU on non-OpenAI transformer models @@ -267,31 +235,23 @@ def create_model( # override model config's image size model_cfg["vision_cfg"]["image_size"] = force_image_size - is_timm_model = "timm_model_name" in model_cfg.get("vision_cfg", {}) + is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {}) if pretrained_image: if is_timm_model: # pretrained weight loading for timm models set via vision_cfg - model_cfg["vision_cfg"]["timm_model_pretrained"] = True + model_cfg['vision_cfg']['timm_model_pretrained'] = True else: - assert ( - False - ), "pretrained image towers currently only supported for timm models" + assert False, 'pretrained image towers currently only supported for timm models' # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes cast_dtype = get_cast_dtype(precision) - is_hf_model = "hf_model_name" in model_cfg.get("text_cfg", {}) + is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {}) if is_hf_model: # load pretrained weights for HF text model IFF no CLIP weights being loaded - model_cfg["text_cfg"]["hf_model_pretrained"] = ( - pretrained_hf and not pretrained - ) - custom_text = ( - model_cfg.pop("custom_text", False) or force_custom_text or is_hf_model - ) + model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained + custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model - model_cfg = dict( - model_cfg, **model_kwargs - ) # merge cfg dict w/ kwargs (kwargs overrides cfg) + model_cfg = dict(model_cfg, **model_kwargs) # merge cfg dict w/ kwargs (kwargs overrides cfg) if custom_text: if "multimodal_cfg" in model_cfg: model = CoCa(**model_cfg, cast_dtype=cast_dtype) @@ -301,7 +261,7 @@ def create_model( model = CLIP(**model_cfg, cast_dtype=cast_dtype) if precision in ("fp16", "bf16"): - dtype = torch.float16 if "fp16" in precision else torch.bfloat16 + dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 # manual mixed precision that matches original OpenAI behaviour if is_timm_model: # FIXME this is a bit janky, create timm based model in low-precision and @@ -314,52 +274,45 @@ def _convert_ln(m): if isinstance(m, LayerNormFp32): m.weight.data = m.weight.data.to(torch.float32) m.bias.data = m.bias.data.to(torch.float32) - model.apply(_convert_ln) else: model.to(device=device) convert_weights_to_lp(model, dtype=dtype) elif precision in ("pure_fp16", "pure_bf16"): - dtype = torch.float16 if "fp16" in precision else torch.bfloat16 + dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 model.to(device=device, dtype=dtype) else: model.to(device=device) pretrained_loaded = False if pretrained: - checkpoint_path = "" + checkpoint_path = '' pretrained_cfg = get_pretrained_cfg(model_name, pretrained) if pretrained_cfg: - checkpoint_path = download_pretrained( - pretrained_cfg, cache_dir=cache_dir - ) + checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir) preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg) elif os.path.exists(pretrained): checkpoint_path = pretrained if checkpoint_path: - logging.info(f"Loading pretrained {model_name} weights ({pretrained}).") + logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') 
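
For reference, the branch above resolves `pretrained` three ways: a tag known to the pretrained registry (fetched with download_pretrained), a local checkpoint path, or — via the hf-hub: prefix handled earlier in create_model — a Hugging Face Hub repo. A minimal usage sketch of the public factory API follows; the tag and Hub repo are examples only (any entry from open_clip.list_pretrained() works) and the local path is hypothetical:

    import open_clip

    print(open_clip.list_models()[:5])      # architectures discovered from model_configs/*.json

    # pretrained as a registry tag: weights are downloaded and cached
    model, _, preprocess = open_clip.create_model_and_transforms(
        'ViT-B-32', pretrained='laion2b_s34b_b79k')
    tokenizer = open_clip.get_tokenizer('ViT-B-32')

    # pretrained as a local file: hits the os.path.exists() branch above
    # model = open_clip.create_model('ViT-B-32', pretrained='/path/to/epoch_32.pt')  # hypothetical path

    # hf-hub: prefix pulls open_clip_config.json and the weights from the Hub
    # model, preprocess = open_clip.create_model_from_pretrained(
    #     'hf-hub:laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
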
load_checkpoint(model, checkpoint_path) else: error_str = ( - f"Pretrained weights ({pretrained}) not found for model {model_name}." - f" Available pretrained tags ({list_pretrained_tags_by_model(model_name)}." - ) + f'Pretrained weights ({pretrained}) not found for model {model_name}.' + f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.') logging.warning(error_str) raise RuntimeError(error_str) pretrained_loaded = True elif has_hf_hub_prefix: - logging.info( - f"Loading pretrained {model_name} weights ({checkpoint_path})." - ) + logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).') load_checkpoint(model, checkpoint_path) pretrained_loaded = True if require_pretrained and not pretrained_loaded: # callers of create_model_from_pretrained always expect pretrained weights raise RuntimeError( - f"Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded." - ) + f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.') if output_dict and hasattr(model, "output_dict"): model.output_dict = True @@ -368,12 +321,10 @@ def _convert_ln(m): model = torch.jit.script(model) # set image preprocessing configuration in model attributes for convenience - if getattr(model.visual, "image_size", None) is not None: + if getattr(model.visual, 'image_size', None) is not None: # use image_size set on model creation (via config or force_image_size arg) - force_preprocess_cfg["size"] = model.visual.image_size - set_model_preprocess_cfg( - model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg) - ) + force_preprocess_cfg['size'] = model.visual.image_size + set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg)) return model @@ -416,33 +367,28 @@ def create_loss(args): def create_model_and_transforms( - model_name: str, - pretrained: Optional[str] = None, - precision: str = "fp32", - device: Union[str, torch.device] = "cpu", - jit: bool = False, - force_quick_gelu: bool = False, - force_custom_text: bool = False, - force_patch_dropout: Optional[float] = None, - force_image_size: Optional[Union[int, Tuple[int, int]]] = None, - image_mean: Optional[Tuple[float, ...]] = None, - image_std: Optional[Tuple[float, ...]] = None, - image_interpolation: Optional[str] = None, - image_resize_mode: Optional[str] = None, # only effective for inference - aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, - pretrained_image: bool = False, - pretrained_hf: bool = True, - cache_dir: Optional[str] = None, - output_dict: Optional[bool] = None, - **model_kwargs, + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_text: bool = False, + force_patch_dropout: Optional[float] = None, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + image_interpolation: Optional[str] = None, + image_resize_mode: Optional[str] = None, # only effective for inference + aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, + pretrained_image: bool = False, + pretrained_hf: bool = True, + cache_dir: Optional[str] = None, + output_dict: Optional[bool] = None, + **model_kwargs, ): force_preprocess_cfg = merge_preprocess_kwargs( - {}, - mean=image_mean, - std=image_std, - 
interpolation=image_interpolation, - resize_mode=image_resize_mode, - ) + {}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode) model = create_model( model_name, @@ -478,29 +424,24 @@ def create_model_and_transforms( def create_model_from_pretrained( - model_name: str, - pretrained: Optional[str] = None, - precision: str = "fp32", - device: Union[str, torch.device] = "cpu", - jit: bool = False, - force_quick_gelu: bool = False, - force_custom_text: bool = False, - force_image_size: Optional[Union[int, Tuple[int, int]]] = None, - image_mean: Optional[Tuple[float, ...]] = None, - image_std: Optional[Tuple[float, ...]] = None, - image_interpolation: Optional[str] = None, - image_resize_mode: Optional[str] = None, # only effective for inference - return_transform: bool = True, - cache_dir: Optional[str] = None, - **model_kwargs, + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_text: bool = False, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + image_interpolation: Optional[str] = None, + image_resize_mode: Optional[str] = None, # only effective for inference + return_transform: bool = True, + cache_dir: Optional[str] = None, + **model_kwargs, ): force_preprocess_cfg = merge_preprocess_kwargs( - {}, - mean=image_mean, - std=image_std, - interpolation=image_interpolation, - resize_mode=image_resize_mode, - ) + {}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode) model = create_model( model_name, diff --git a/src/open_clip/model.py b/src/open_clip/model.py index 54150da90..e743cc50c 100644 --- a/src/open_clip/model.py +++ b/src/open_clip/model.py @@ -6,26 +6,20 @@ import logging import math from dataclasses import dataclass -from functools import partial from typing import Any, Dict, Optional, Tuple, Union import numpy as np import torch import torch.nn.functional as F from torch import nn +from torch.utils.checkpoint import checkpoint +from functools import partial from .hf_model import HFTextEncoder from .modified_resnet import ModifiedResNet from .timm_model import TimmModel -from .transformer import ( - Attention, - LayerNorm, - LayerNormFp32, - QuickGELU, - TextTransformer, - VisionTransformer, - text_global_pool, -) +from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer,\ + text_global_pool from .utils import to_2tuple @@ -39,32 +33,24 @@ class CLIPVisionCfg: image_size: Union[Tuple[int, int], int] = 224 ls_init_value: Optional[float] = None # layer scale initial value - patch_dropout: float = 0.0 # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results + patch_dropout: float = 0. 
# what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer (overrides pool_type) attn_pooler_queries: int = 256 # n_queries for attentional pooler attn_pooler_heads: int = 8 # n heads for attentional_pooling no_ln_pre: bool = False # disable pre transformer LayerNorm - pos_embed_type: str = "learnable" + pos_embed_type: str = 'learnable' final_ln_after_pool: bool = False # apply final LayerNorm after pooling - pool_type: str = "tok" + pool_type: str = 'tok' output_tokens: bool = False act_kwargs: Optional[dict] = None norm_kwargs: Optional[dict] = None - timm_model_name: Optional[ - str - ] = None # a valid model name overrides layers, width, patch_size - timm_model_pretrained: bool = ( - False # use (imagenet) pretrained weights for named model - ) - timm_pool: str = ( - "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') - ) - timm_proj: str = ( - "linear" # linear projection for timm model output ('linear', 'mlp', '') - ) + timm_model_name: Optional[str] = None # a valid model name overrides layers, width, patch_size + timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model + timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') + timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') timm_proj_bias: bool = False # enable bias final projection - timm_drop: float = 0.0 # head dropout + timm_drop: float = 0. # head dropout timm_drop_path: Optional[float] = None # backbone stochastic depth @@ -84,7 +70,7 @@ class CLIPTextCfg: pad_id: int = 0 no_causal_mask: bool = False # disable causal masking final_ln_after_pool: bool = False # apply final LayerNorm after pooling - pool_type: str = "argmax" + pool_type: str = 'argmax' proj_bias: bool = False output_tokens: bool = False act_kwargs: dict = None @@ -93,33 +79,33 @@ class CLIPTextCfg: # HuggingFace specific text tower config hf_model_name: Optional[str] = None hf_model_pretrained: bool = True - hf_proj_type: str = "mlp" - hf_pooler_type: str = "mean_pooler" # attentional pooling for HF models + hf_proj_type: str = 'mlp' + hf_pooler_type: str = 'mean_pooler' # attentional pooling for HF models def get_cast_dtype(precision: str): cast_dtype = None - if precision == "bf16": + if precision == 'bf16': cast_dtype = torch.bfloat16 - elif precision == "fp16": + elif precision == 'fp16': cast_dtype = torch.float16 return cast_dtype def get_input_dtype(precision: str): input_dtype = None - if precision in ("bf16", "pure_bf16"): + if precision in ('bf16', 'pure_bf16'): input_dtype = torch.bfloat16 - elif precision in ("fp16", "pure_fp16"): + elif precision in ('fp16', 'pure_fp16'): input_dtype = torch.float16 return input_dtype def _build_vision_tower( - embed_dim: int, - vision_cfg: CLIPVisionCfg, - quick_gelu: bool = False, - cast_dtype: Optional[torch.dtype] = None, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None ): if isinstance(vision_cfg, dict): vision_cfg = CLIPVisionCfg(**vision_cfg) @@ -138,9 +124,7 @@ def _build_vision_tower( proj_bias=vision_cfg.timm_proj_bias, drop=vision_cfg.timm_drop, drop_path=vision_cfg.timm_drop_path, - patch_drop=vision_cfg.patch_dropout - if vision_cfg.patch_dropout > 0 - else None, + patch_drop=vision_cfg.patch_dropout if 
vision_cfg.patch_dropout > 0 else None, embed_dim=embed_dim, image_size=vision_cfg.image_size, ) @@ -155,11 +139,7 @@ def _build_vision_tower( ) else: vision_heads = vision_cfg.width // vision_cfg.head_width - norm_layer = ( - LayerNormFp32 - if cast_dtype in (torch.float16, torch.bfloat16) - else LayerNorm - ) + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm if vision_cfg.norm_kwargs: norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs) if vision_cfg.act_kwargs is not None: @@ -191,10 +171,10 @@ def _build_vision_tower( def _build_text_tower( - embed_dim: int, - text_cfg: CLIPTextCfg, - quick_gelu: bool = False, - cast_dtype: Optional[torch.dtype] = None, + embed_dim: int, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, ): if isinstance(text_cfg, dict): text_cfg = CLIPTextCfg(**text_cfg) @@ -210,11 +190,7 @@ def _build_text_tower( ) else: act_layer = QuickGELU if quick_gelu else nn.GELU - norm_layer = ( - LayerNormFp32 - if cast_dtype in (torch.float16, torch.bfloat16) - else LayerNorm - ) + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm if text_cfg.norm_kwargs: norm_layer = partial(norm_layer, **text_cfg.norm_kwargs) if text_cfg.act_kwargs is not None: @@ -245,15 +221,15 @@ class CLIP(nn.Module): output_dict: torch.jit.Final[bool] def __init__( - self, - embed_dim: int, - vision_cfg: CLIPVisionCfg, - text_cfg: CLIPTextCfg, - quick_gelu: bool = False, - init_logit_scale: float = np.log(1 / 0.07), - init_logit_bias: Optional[float] = None, - cast_dtype: Optional[torch.dtype] = None, - output_dict: bool = False, + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + init_logit_scale: float = np.log(1 / 0.07), + init_logit_bias: Optional[float] = None, + cast_dtype: Optional[torch.dtype] = None, + output_dict: bool = False, ): super().__init__() self.output_dict = output_dict @@ -269,7 +245,7 @@ def __init__( self.ln_final = text.ln_final self.text_projection = text.text_projection self.text_pool_type = text.pool_type - self.register_buffer("attn_mask", text.attn_mask, persistent=False) + self.register_buffer('attn_mask', text.attn_mask, persistent=False) self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale) if init_logit_bias is not None: @@ -279,9 +255,7 @@ def __init__( def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 - self.visual.lock( - unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats - ) + self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) @torch.jit.ignore def set_grad_checkpointing(self, enable=True): @@ -321,34 +295,25 @@ def get_logits(self, image, text): return image_logits, text_logits def forward( - self, - image: Optional[torch.Tensor] = None, - text: Optional[torch.Tensor] = None, + self, + image: Optional[torch.Tensor] = None, + text: Optional[torch.Tensor] = None, ): - image_features = ( - self.encode_image(image, normalize=True) if image is not None else None - ) - text_features = ( - self.encode_text(text, normalize=True) if text is not None else None - ) + image_features = self.encode_image(image, normalize=True) if image is not None else None + text_features = self.encode_text(text, normalize=True) if text is not None else None if self.output_dict: out_dict = { "image_features": image_features, "text_features": text_features, - 
"logit_scale": self.logit_scale.exp(), + "logit_scale": self.logit_scale.exp() } if self.logit_bias is not None: - out_dict["logit_bias"] = self.logit_bias + out_dict['logit_bias'] = self.logit_bias return out_dict if self.logit_bias is not None: - return ( - image_features, - text_features, - self.logit_scale.exp(), - self.logit_bias, - ) + return image_features, text_features, self.logit_scale.exp(), self.logit_bias return image_features, text_features, self.logit_scale.exp() @@ -356,15 +321,15 @@ class CustomTextCLIP(nn.Module): output_dict: torch.jit.Final[bool] def __init__( - self, - embed_dim: int, - vision_cfg: CLIPVisionCfg, - text_cfg: CLIPTextCfg, - quick_gelu: bool = False, - init_logit_scale: float = np.log(1 / 0.07), - init_logit_bias: Optional[float] = None, - cast_dtype: Optional[torch.dtype] = None, - output_dict: bool = False, + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + init_logit_scale: float = np.log(1 / 0.07), + init_logit_bias: Optional[float] = None, + cast_dtype: Optional[torch.dtype] = None, + output_dict: bool = False, ): super().__init__() self.output_dict = output_dict @@ -380,9 +345,7 @@ def __init__( def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 - self.visual.lock( - unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats - ) + self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True): self.text.lock(unlocked_layers, freeze_layer_norm) @@ -410,34 +373,25 @@ def get_logits(self, image, text): return image_logits, text_logits def forward( - self, - image: Optional[torch.Tensor] = None, - text: Optional[torch.Tensor] = None, + self, + image: Optional[torch.Tensor] = None, + text: Optional[torch.Tensor] = None, ): - image_features = ( - self.encode_image(image, normalize=True) if image is not None else None - ) - text_features = ( - self.encode_text(text, normalize=True) if text is not None else None - ) + image_features = self.encode_image(image, normalize=True) if image is not None else None + text_features = self.encode_text(text, normalize=True) if text is not None else None if self.output_dict: out_dict = { "image_features": image_features, "text_features": text_features, - "logit_scale": self.logit_scale.exp(), + "logit_scale": self.logit_scale.exp() } if self.logit_bias is not None: - out_dict["logit_bias"] = self.logit_bias + out_dict['logit_bias'] = self.logit_bias return out_dict if self.logit_bias is not None: - return ( - image_features, - text_features, - self.logit_scale.exp(), - self.logit_bias, - ) + return image_features, text_features, self.logit_scale.exp(), self.logit_bias return image_features, text_features, self.logit_scale.exp() @@ -451,12 +405,7 @@ def _convert_weights(l): l.bias.data = l.bias.data.to(dtype) if isinstance(l, (nn.MultiheadAttention, Attention)): - for attr in [ - *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], - "in_proj_bias", - "bias_k", - "bias_v", - ]: + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: tensor = getattr(l, attr) if tensor is not None: tensor.data = tensor.data.to(dtype) @@ -481,68 +430,45 @@ def _convert_weights(l): # used to maintain checkpoint compatibility def convert_to_custom_text_state_dict(state_dict: dict): - if "text_projection" in state_dict: + if 'text_projection' 
in state_dict: # old format state_dict, move text tower -> .text new_state_dict = {} for k, v in state_dict.items(): - if any( - k.startswith(p) - for p in ( - "text_projection", - "positional_embedding", - "token_embedding", - "transformer", - "ln_final", - ) - ): - k = "text." + k + if any(k.startswith(p) for p in ( + 'text_projection', + 'positional_embedding', + 'token_embedding', + 'transformer', + 'ln_final', + )): + k = 'text.' + k new_state_dict[k] = v return new_state_dict return state_dict def build_model_from_openai_state_dict( - state_dict: dict, - quick_gelu=True, - cast_dtype=torch.float16, + state_dict: dict, + quick_gelu=True, + cast_dtype=torch.float16, ): vit = "visual.proj" in state_dict if vit: vision_width = state_dict["visual.conv1.weight"].shape[0] vision_layers = len( - [ - k - for k in state_dict.keys() - if k.startswith("visual.") and k.endswith(".attn.in_proj_weight") - ] - ) + [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] - grid_size = round( - (state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5 - ) + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) image_size = vision_patch_size * grid_size else: counts: list = [ - len( - set( - k.split(".")[2] - for k in state_dict - if k.startswith(f"visual.layer{b}") - ) - ) - for b in [1, 2, 3, 4] - ] + len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] vision_layers = tuple(counts) vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] - output_width = round( - (state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5 - ) + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) vision_patch_size = None - assert ( - output_width**2 + 1 - == state_dict["visual.attnpool.positional_embedding"].shape[0] - ) + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] image_size = output_width * 32 embed_dim = state_dict["text_projection"].shape[1] @@ -550,13 +476,7 @@ def build_model_from_openai_state_dict( vocab_size = state_dict["token_embedding.weight"].shape[0] transformer_width = state_dict["ln_final.weight"].shape[0] transformer_heads = transformer_width // 64 - transformer_layers = len( - set( - k.split(".")[2] - for k in state_dict - if k.startswith(f"transformer.resblocks") - ) - ) + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) vision_cfg = CLIPVisionCfg( layers=vision_layers, @@ -581,62 +501,46 @@ def build_model_from_openai_state_dict( for key in ["input_resolution", "context_length", "vocab_size"]: state_dict.pop(key, None) - convert_weights_to_fp16( - model - ) # OpenAI state dicts are partially converted to float16 + convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16 model.load_state_dict(state_dict) return model.eval() -def trace_model(model, batch_size=256, device=torch.device("cpu")): +def trace_model(model, batch_size=256, device=torch.device('cpu')): model.eval() image_size = model.visual.image_size example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) - example_text = torch.zeros( - (batch_size, model.context_length), dtype=torch.int, device=device - ) + example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device) model = 
torch.jit.trace_module( model, inputs=dict( forward=(example_images, example_text), encode_text=(example_text,), - encode_image=(example_images,), - ), - ) + encode_image=(example_images,) + )) model.visual.image_size = image_size return model -def resize_pos_embed( - state_dict, model, interpolation: str = "bicubic", antialias: bool = True -): +def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True): # Rescale the grid of position embeddings when loading from state_dict - old_pos_embed = state_dict.get("visual.positional_embedding", None) - if old_pos_embed is None or not hasattr(model.visual, "grid_size"): + old_pos_embed = state_dict.get('visual.positional_embedding', None) + if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): return grid_size = to_2tuple(model.visual.grid_size) - extra_tokens = ( - 1 # FIXME detect different token configs (ie no class token, or more) - ) + extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) new_seq_len = grid_size[0] * grid_size[1] + extra_tokens if new_seq_len == old_pos_embed.shape[0]: return if extra_tokens: - pos_emb_tok, pos_emb_img = ( - old_pos_embed[:extra_tokens], - old_pos_embed[extra_tokens:], - ) + pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] else: pos_emb_tok, pos_emb_img = None, old_pos_embed old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) - logging.info( - "Resizing position embedding grid-size from %s to %s", old_grid_size, grid_size - ) - pos_emb_img = pos_emb_img.reshape( - 1, old_grid_size[0], old_grid_size[1], -1 - ).permute(0, 3, 1, 2) + logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) + pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) pos_emb_img = F.interpolate( pos_emb_img, size=grid_size, @@ -644,38 +548,32 @@ def resize_pos_embed( antialias=antialias, align_corners=False, ) - pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape( - 1, grid_size[0] * grid_size[1], -1 - )[0] + pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] if pos_emb_tok is not None: new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) else: new_pos_embed = pos_emb_img - state_dict["visual.positional_embedding"] = new_pos_embed + state_dict['visual.positional_embedding'] = new_pos_embed -def resize_text_pos_embed( - state_dict, model, interpolation: str = "linear", antialias: bool = False -): - old_pos_embed = state_dict.get("positional_embedding", None) +def resize_text_pos_embed(state_dict, model, interpolation: str = 'linear', antialias: bool = False): + old_pos_embed = state_dict.get('positional_embedding', None) if old_pos_embed is None: return # FIXME add support for text cls_token - model_pos_embed = getattr(model, "positional_embedding", None) + model_pos_embed = getattr(model, 'positional_embedding', None) if model_pos_embed is None: - model_pos_embed = getattr(model.text, "positional_embedding", None) + model_pos_embed = getattr(model.text, 'positional_embedding', None) old_num_pos = old_pos_embed.shape[0] old_width = old_pos_embed.shape[1] num_pos = model_pos_embed.shape[0] width = model_pos_embed.shape[1] - assert old_width == width, "text pos_embed width changed!" + assert old_width == width, 'text pos_embed width changed!' 
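
The two resize helpers here (together with resize_pos_embed just above) are what let a checkpoint trained at one resolution or context length load into a model built at another: factory.load_checkpoint() runs them on the prepared state_dict before load_state_dict. A small sketch of the visual case — the pretrained tag is an example, and any ViT-B-16 checkpoint trained at 224 px behaves the same:

    import torch
    import open_clip

    # Build the model at 336 px even though the checkpoint was trained at 224 px;
    # resize_pos_embed() interpolates the 14x14 position grid to 21x21 (bicubic)
    # while the checkpoint is loaded, so the weights still apply.
    model = open_clip.create_model(
        'ViT-B-16',
        pretrained='laion2b_s34b_b88k',   # example tag; trained at image_size=224
        force_image_size=336,
    )

    with torch.no_grad():
        feats = model.encode_image(torch.randn(1, 3, 336, 336))
    print(feats.shape)   # torch.Size([1, 512])
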
if old_num_pos == num_pos: return - logging.info( - "Resizing text position embedding num_pos from %s to %s", old_num_pos, num_pos - ) + logging.info('Resizing text position embedding num_pos from %s to %s', old_num_pos, num_pos) old_pos_embed = old_pos_embed.reshape(1, old_num_pos, old_width).permute(0, 2, 1) old_pos_embed = F.interpolate( old_pos_embed, @@ -687,44 +585,40 @@ def resize_text_pos_embed( old_pos_embed = old_pos_embed.permute(0, 2, 1)[0] new_pos_embed = old_pos_embed - state_dict["positional_embedding"] = new_pos_embed + state_dict['positional_embedding'] = new_pos_embed def get_model_preprocess_cfg(model): - module = getattr(model, "visual", model) - preprocess_cfg = getattr(module, "preprocess_cfg", {}) + module = getattr(model, 'visual', model) + preprocess_cfg = getattr(module, 'preprocess_cfg', {}) if not preprocess_cfg: # use separate legacy attributes if preprocess_cfg dict not found - size = getattr(module, "image_size") + size = getattr(module, 'image_size') if size is not None: - preprocess_cfg["size"] = size - mean = getattr(module, "image_mean", None) + preprocess_cfg['size'] = size + mean = getattr(module, 'image_mean', None) if mean is not None: - preprocess_cfg["mean"] = mean - std = getattr(module, "image_std", None) + preprocess_cfg['mean'] = mean + std = getattr(module, 'image_std', None) if std is not None: - preprocess_cfg["std"] = std + preprocess_cfg['std'] = std return preprocess_cfg def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]): - module = getattr(model, "visual", model) - module.image_mean = preprocess_cfg[ - "mean" - ] # legacy attribute, keeping for bwd compat - module.image_std = preprocess_cfg["std"] # legacy attribute, keeping for bwd compat - module.preprocess_cfg = copy.deepcopy( - preprocess_cfg - ) # new attr, package all pp cfg as dict + module = getattr(model, 'visual', model) + module.image_mean = preprocess_cfg['mean'] # legacy attribute, keeping for bwd compat + module.image_std = preprocess_cfg['std'] # legacy attribute, keeping for bwd compat + module.preprocess_cfg = copy.deepcopy(preprocess_cfg) # new attr, package all pp cfg as dict def get_model_tokenize_cfg(model): - module = getattr(model, "text", model) + module = getattr(model, 'text', model) cfg = {} - context_length = getattr(module, "context_length", None) + context_length = getattr(module, 'context_length', None) if context_length is not None: - cfg["context_length"] = context_length - vocab_size = getattr(module, "vocab_size", None) + cfg['context_length'] = context_length + vocab_size = getattr(module, 'vocab_size', None) if vocab_size is not None: - cfg["vocab_size"] = vocab_size - return cfg + cfg['vocab_size'] = vocab_size + return cfg \ No newline at end of file diff --git a/src/open_clip/tokenizer.py b/src/open_clip/tokenizer.py index 1518cd441..825a1431e 100644 --- a/src/open_clip/tokenizer.py +++ b/src/open_clip/tokenizer.py @@ -24,9 +24,7 @@ @lru_cache() def default_bpe(): - return os.path.join( - os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz" - ) + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") @lru_cache() @@ -40,17 +38,13 @@ def bytes_to_unicode(): To avoid that, we want lookup tables between utf-8 bytes and unicode strings. And avoids mapping to whitespace/control characters the bpe code barfs on. 
""" - bs = ( - list(range(ord("!"), ord("~") + 1)) - + list(range(ord("¡"), ord("¬") + 1)) - + list(range(ord("®"), ord("ÿ") + 1)) - ) + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) cs = bs[:] n = 0 for b in range(2**8): if b not in bs: bs.append(b) - cs.append(2**8 + n) + cs.append(2**8+n) n += 1 cs = [chr(n) for n in cs] return dict(zip(bs, cs)) @@ -75,7 +69,7 @@ def basic_clean(text): def whitespace_clean(text): - text = re.sub(r"\s+", " ", text) + text = re.sub(r'\s+', ' ', text) text = text.strip() return text @@ -96,11 +90,11 @@ def _clean_whitespace(x): def get_clean_fn(type: str): - if type == "canonicalize": + if type == 'canonicalize': return _clean_canonicalize - elif type == "lower": + elif type == 'lower': return _clean_lower - elif type == "whitespace": + elif type == 'whitespace': return _clean_whitespace else: assert False, f"Invalid clean function ({type})." @@ -121,8 +115,7 @@ def canonicalize_text(text, *, keep_punctuation_exact_string=None): if keep_punctuation_exact_string: text = keep_punctuation_exact_string.join( part.translate(str.maketrans("", "", string.punctuation)) - for part in text.split(keep_punctuation_exact_string) - ) + for part in text.split(keep_punctuation_exact_string)) else: text = text.translate(str.maketrans("", "", string.punctuation)) text = text.lower() @@ -132,30 +125,30 @@ def canonicalize_text(text, *, keep_punctuation_exact_string=None): class SimpleTokenizer(object): def __init__( - self, - bpe_path: str = default_bpe(), - additional_special_tokens: Optional[List[str]] = None, - context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH, - clean: str = "lower", - reduction_mask: str = "", + self, + bpe_path: str = default_bpe(), + additional_special_tokens: Optional[List[str]] = None, + context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH, + clean: str = 'lower', + reduction_mask: str = '' ): self.byte_encoder = bytes_to_unicode() self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} - merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") - merges = merges[1 : 49152 - 256 - 2 + 1] + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] merges = [tuple(merge.split()) for merge in merges] vocab = list(bytes_to_unicode().values()) - vocab = vocab + [v + "" for v in vocab] + vocab = vocab + [v+'' for v in vocab] for merge in merges: - vocab.append("".join(merge)) - special_tokens = ["", ""] + vocab.append(''.join(merge)) + special_tokens = ['', ''] if additional_special_tokens: special_tokens += additional_special_tokens vocab.extend(special_tokens) self.encoder = dict(zip(vocab, range(len(vocab)))) self.decoder = {v: k for k, v in self.encoder.items()} self.bpe_ranks = dict(zip(merges, range(len(merges)))) - self.cache = {t: t for t in special_tokens} + self.cache = {t:t for t in special_tokens} special = "|".join(special_tokens) self.pat = re.compile( special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", @@ -167,21 +160,19 @@ def __init__( self.eot_token_id = self.all_special_ids[1] self.context_length = context_length self.clean_fn = get_clean_fn(clean) - self.reduction_fn = ( - get_reduction_mask_fn(reduction_mask) if reduction_mask else None - ) + self.reduction_fn = get_reduction_mask_fn(reduction_mask) if reduction_mask else None def bpe(self, token): if token in self.cache: return self.cache[token] - word = tuple(token[:-1]) + (token[-1] + "",) + word = tuple(token[:-1]) + ( token[-1] + 
'',) pairs = get_pairs(word) if not pairs: - return token + "" + return token+'' while True: - bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) if bigram not in self.bpe_ranks: break first, second = bigram @@ -196,8 +187,8 @@ def bpe(self, token): new_word.extend(word[i:]) break - if word[i] == first and i < len(word) - 1 and word[i + 1] == second: - new_word.append(first + second) + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) i += 2 else: new_word.append(word[i]) @@ -208,7 +199,7 @@ def bpe(self, token): break else: pairs = get_pairs(word) - word = " ".join(word) + word = ' '.join(word) self.cache[token] = word return word @@ -216,25 +207,17 @@ def encode(self, text): bpe_tokens = [] text = self.clean_fn(text) for token in re.findall(self.pat, text): - token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) - bpe_tokens.extend( - self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") - ) + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) return bpe_tokens def decode(self, tokens): - text = "".join([self.decoder[token] for token in tokens]) - text = ( - bytearray([self.byte_decoder[c] for c in text]) - .decode("utf-8", errors="replace") - .replace("", " ") - ) + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') return text - def __call__( - self, texts: Union[str, List[str]], context_length: Optional[int] = None - ) -> torch.LongTensor: - """Returns the tokenized representation of given input string(s) + def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.LongTensor: + """ Returns the tokenized representation of given input string(s) Parameters ---------- @@ -251,7 +234,7 @@ def __call__( texts = [texts] context_length = context_length or self.context_length - assert context_length, "Please set a valid context length" + assert context_length, 'Please set a valid context length' if self.reduction_fn is not None: # use reduction strategy for tokenize if set, otherwise default to truncation below @@ -263,17 +246,14 @@ def __call__( encode_fn=self.encode, ) - all_tokens = [ - [self.sot_token_id] + self.encode(text) + [self.eot_token_id] - for text in texts - ] + all_tokens = [[self.sot_token_id] + self.encode(text) + [self.eot_token_id] for text in texts] result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) for i, tokens in enumerate(all_tokens): if len(tokens) > context_length: tokens = tokens[:context_length] # Truncate tokens[-1] = self.eot_token_id - result[i, : len(tokens)] = torch.tensor(tokens) + result[i, :len(tokens)] = torch.tensor(tokens) return result @@ -286,19 +266,17 @@ def decode(output_ids: torch.Tensor): return _tokenizer.decode(output_ids) -def tokenize( - texts: Union[str, List[str]], context_length: int = DEFAULT_CONTEXT_LENGTH -) -> torch.LongTensor: +def tokenize(texts: Union[str, List[str]], context_length: int = DEFAULT_CONTEXT_LENGTH) -> torch.LongTensor: return _tokenizer(texts, context_length=context_length) def random_mask_tokenize( - texts: Union[str, List[str]], - context_length: int, - sot_token_id: int, - eot_token_id: int, - encode_fn: Callable, - shuffle: bool = False, + texts: Union[str, 
List[str]], + context_length: int, + sot_token_id: int, + eot_token_id: int, + encode_fn: Callable, + shuffle: bool = False, ): all_tokens = [encode_fn(text) for text in texts] result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) @@ -315,18 +293,18 @@ def random_mask_tokenize( tokens = tokens[indices] num_tokens = num_keep result[i, 0] = sot_token_id - result[i, 1 : num_tokens + 1] = tokens + result[i, 1:num_tokens + 1] = tokens result[i, num_tokens + 1] = eot_token_id return result def simple_mask_tokenize( - texts: Union[str, List[str]], - context_length: int, - sot_token_id: int, - eot_token_id: int, - encode_fn: Callable, + texts: Union[str, List[str]], + context_length: int, + sot_token_id: int, + eot_token_id: int, + encode_fn: Callable, ): all_tokens = [encode_fn(text) for text in texts] result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) @@ -336,38 +314,37 @@ def simple_mask_tokenize( if num_tokens > context_length - 2: # 2 for sot and eot token num_keep = context_length - 2 start_index = random.randint(0, num_tokens - num_keep) # high is incl - tokens = tokens[start_index : start_index + num_keep] + tokens = tokens[start_index: start_index + num_keep] tokens = [sot_token_id] + tokens + [eot_token_id] - result[i, : len(tokens)] = torch.tensor(tokens) + result[i, :len(tokens)] = torch.tensor(tokens) return result def syntax_mask_tokenize( - texts: Union[str, List[str]], - context_length: int, - sot_token_id: int, - eot_token_id: int, - encode_fn: Callable, + texts: Union[str, List[str]], + context_length: int, + sot_token_id: int, + eot_token_id: int, + encode_fn: Callable, ) -> torch.LongTensor: - """Returns the tokenized representation of given input string(s). + """ Returns the tokenized representation of given input string(s). Apply syntax masking before tokenize. 
""" import nltk - global _nltk_init if not _nltk_init: # run them for the first time - nltk.download("punkt") - nltk.download("averaged_perceptron_tagger") + nltk.download('punkt') + nltk.download('averaged_perceptron_tagger') _nltk_init = True def get_order(x): - if x.startswith("NN"): + if x.startswith('NN'): return 1 - elif x.startswith("JJ"): + elif x.startswith('JJ'): return 2 - elif x.startswith("VB"): + elif x.startswith('VB'): return 3 else: return 4 @@ -380,16 +357,12 @@ def get_order(x): # sample the words by get_order method order_list = [get_order(tag) for _, tag in pos_tags] sorted_ids = np.argsort(np.array(order_list)) - sampled_ids = sorted( - sorted_ids[: context_length - 2] - ) # need 2 slots for sot and eot tokens - sampled_tokens = np.take( - np.array(list_tokens), sampled_ids, axis=0 - ) # sample the tokens - - new_text = "" + sampled_ids = sorted(sorted_ids[:context_length - 2]) # need 2 slots for sot and eot tokens + sampled_tokens = np.take(np.array(list_tokens), sampled_ids, axis=0) # sample the tokens + + new_text = '' for token in sampled_tokens: - new_text = new_text + str(token) + " " + new_text = new_text + str(token) + ' ' new_text = new_text.strip() new_texts.append(new_text) texts = new_texts @@ -402,23 +375,21 @@ def get_order(x): if len(tokens) > context_length: tokens = tokens[:context_length] # Truncate tokens[-1] = eot_token_id - result[i, : len(tokens)] = torch.tensor(tokens) + result[i, :len(tokens)] = torch.tensor(tokens) return result def get_reduction_mask_fn(type: str): - """Choose strategy for dropping (masking) tokens to achieve target context length""" - assert type in ("simple", "random", "shuffle", "syntax") - if type == "simple": + """ Choose strategy for dropping (masking) tokens to achieve target context length""" + assert type in ('simple', 'random', 'shuffle', 'syntax') + if type == 'simple': return simple_mask_tokenize # randomly select block [start:end] - elif type == "random": + elif type == 'random': return random_mask_tokenize # randomly drop tokens (keep order) - elif type == "shuffle": - return partial( - random_mask_tokenize, shuffle=True - ) # randomly drop tokens (shuffle order) - elif type == "syntax": + elif type == 'shuffle': + return partial(random_mask_tokenize, shuffle=True) # randomly drop tokens (shuffle order) + elif type == 'syntax': return syntax_mask_tokenize # randomly drop prioritized by syntax @@ -426,14 +397,13 @@ class HFTokenizer: """HuggingFace tokenizer wrapper""" def __init__( - self, - tokenizer_name: str, - context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH, - clean: str = "whitespace", - strip_sep_token: bool = False, + self, + tokenizer_name: str, + context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH, + clean: str = 'whitespace', + strip_sep_token: bool = False, ): from transformers import AutoTokenizer - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) self.context_length = context_length self.clean_fn = get_clean_fn(clean) @@ -442,25 +412,21 @@ def __init__( def save_pretrained(self, dest): self.tokenizer.save_pretrained(dest) - def __call__( - self, texts: Union[str, List[str]], context_length: Optional[int] = None - ) -> torch.Tensor: + def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor: # same cleaning as for default tokenizer, except lowercasing # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance if isinstance(texts, str): texts = [texts] context_length = context_length or 
self.context_length - assert ( - context_length - ), "Please set a valid context length in class init or call." + assert context_length, 'Please set a valid context length in class init or call.' texts = [self.clean_fn(text) for text in texts] input_ids = self.tokenizer.batch_encode_plus( texts, - return_tensors="pt", + return_tensors='pt', max_length=context_length, - padding="max_length", + padding='max_length', truncation=True, ).input_ids @@ -475,8 +441,8 @@ def __call__( class SigLipTokenizer: - """HuggingFace tokenizer wrapper for SigLIP T5 compatible sentencepiece vocabs""" - + """HuggingFace tokenizer wrapper for SigLIP T5 compatible sentencepiece vocabs + """ VOCAB_FILES = { # english, vocab_size=32_000 "c4-en": "http://storage.googleapis.com/t5-data/vocabs/cc_en.32000/sentencepiece.model", @@ -485,9 +451,9 @@ class SigLipTokenizer: } def __init__( - self, - tokenizer_name: str, - context_length: Optional[int] = 64, + self, + tokenizer_name: str, + context_length: Optional[int] = 64, ): from transformers import T5TokenizerFast @@ -496,10 +462,9 @@ def __init__( import tempfile import fsspec - vocab_file = self.VOCAB_FILES[tokenizer_name] - with tempfile.NamedTemporaryFile("wb") as dst: - with fsspec.open(vocab_file, "rb") as src: + with tempfile.NamedTemporaryFile('wb') as dst: + with fsspec.open(vocab_file, 'rb') as src: dst.write(src.read()) self.tokenizer = T5TokenizerFast(dst.name, legacy=False) else: @@ -512,25 +477,21 @@ def __init__( def save_pretrained(self, dest): self.tokenizer.save_pretrained(dest) - def __call__( - self, texts: Union[str, List[str]], context_length: Optional[int] = None - ) -> torch.Tensor: + def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor: # same cleaning as for default tokenizer, except lowercasing # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance if isinstance(texts, str): texts = [texts] context_length = context_length or self.context_length - assert ( - context_length - ), "Please set a valid context length in class init or call." + assert context_length, 'Please set a valid context length in class init or call.' texts = [canonicalize_text(basic_clean(text)) for text in texts] output = self.tokenizer( texts, - return_tensors="pt", + return_tensors='pt', max_length=context_length, - padding="max_length", + padding='max_length', truncation=True, ) return output.input_ids
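
The tokenizer.py changes above appear to be quoting and line-wrap style only; behaviour is unchanged: SimpleTokenizer keeps the stock CLIP BPE vocab (49,408 entries, end-of-word marked with a trailing </w>, specials <start_of_text> = 49406 and <end_of_text> = 49407), pads or truncates to context_length, and HF-backed text configs route through HFTokenizer with the same __call__ contract. A short behaviour sketch; the xlm-roberta config name is just an example of an HF-tokenized model:

    from open_clip.tokenizer import SimpleTokenizer

    tok = SimpleTokenizer()                     # context_length defaults to 77
    ids = tok(["a photo of a cat", "a dog"])    # LongTensor, shape (2, 77), zero-padded
    print(ids.shape, tok.sot_token_id, tok.eot_token_id)   # torch.Size([2, 77]) 49406 49407

    # decode() maps ids back through the BPE vocab and turns </w> markers into spaces
    print(tok.decode([i for i in ids[0].tolist() if i != 0]))
    # -> '<start_of_text>a photo of a cat <end_of_text>'

    # Reduction masks trade truncation for token dropping (syntax mode needs nltk)
    tok_random = SimpleTokenizer(reduction_mask='random')

    # HF-backed configs return an HFTokenizer with the same calling convention:
    # import open_clip
    # hf_tok = open_clip.get_tokenizer('xlm-roberta-base-ViT-B-32')   # example config name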