From f3c1af5f5f979cce6e65a668d1d6fa15372f2f30 Mon Sep 17 00:00:00 2001
From: Nahid Alam
Date: Sun, 14 Jul 2024 21:57:04 +0000
Subject: [PATCH 1/2] Add SigLIP vision encoder

---
 llava/model/multimodal_encoder/builder.py |  4 +-
 .../multimodal_encoder/siglip_encoder.py  | 87 +++++++++++++++++++
 scripts/v1_5/pretrain_siglip.sh           | 35 ++++++++
 3 files changed, 125 insertions(+), 1 deletion(-)
 create mode 100644 llava/model/multimodal_encoder/siglip_encoder.py
 create mode 100644 scripts/v1_5/pretrain_siglip.sh

diff --git a/llava/model/multimodal_encoder/builder.py b/llava/model/multimodal_encoder/builder.py
index 29f63a26d..afad6ebc1 100644
--- a/llava/model/multimodal_encoder/builder.py
+++ b/llava/model/multimodal_encoder/builder.py
@@ -1,6 +1,6 @@
 import os
 from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2
-
+from .siglip_encoder import SiglipVisionTower
 
 def build_vision_tower(vision_tower_cfg, **kwargs):
     vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
@@ -11,5 +11,7 @@ def build_vision_tower(vision_tower_cfg, **kwargs):
             return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
         else:
             return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
+    elif 'siglip' in vision_tower:
+        return SiglipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
 
     raise ValueError(f'Unknown vision tower: {vision_tower}')
diff --git a/llava/model/multimodal_encoder/siglip_encoder.py b/llava/model/multimodal_encoder/siglip_encoder.py
new file mode 100644
index 000000000..ac0b827fb
--- /dev/null
+++ b/llava/model/multimodal_encoder/siglip_encoder.py
@@ -0,0 +1,87 @@
+import torch
+import torch.nn as nn
+
+from transformers import SiglipVisionModel, SiglipImageProcessor, SiglipVisionConfig
+
+class SiglipVisionTower(nn.Module):
+    def __init__(self, vision_tower, args, delay_load=False):
+        super().__init__()
+
+        self.is_loaded = False
+
+        self.vision_tower_name = vision_tower
+        self.select_layer = args.mm_vision_select_layer
+        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
+
+        if not delay_load:
+            self.load_model()
+        elif getattr(args, 'unfreeze_mm_vision_tower', False):
+            self.load_model()
+        else:
+            self.cfg_only = SiglipVisionConfig.from_pretrained(self.vision_tower_name)
+
+    def load_model(self, device_map=None):
+        if self.is_loaded:
+            print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
+            return
+
+        self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_tower_name)
+        self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
+        self.vision_tower.requires_grad_(False)
+
+        self.is_loaded = True
+
+    def feature_select(self, image_forward_outs):
+        image_features = image_forward_outs.hidden_states[self.select_layer]
+        if self.select_feature == 'patch':
+            image_features = image_features  # SigLIP has no [CLS] token, so every token is already a patch token
+        elif self.select_feature == 'cls_patch':
+            raise ValueError('SigLIP has no [CLS] token, so `cls_patch` is not supported.')
+        else:
+            raise ValueError(f'Unexpected select feature: {self.select_feature}')
+        return image_features
+
+    @torch.no_grad()
+    def forward(self, images):
+        if type(images) is list:
+            image_features = []
+            for image in images:
+                image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
+                image_feature = self.feature_select(image_forward_out).to(image.dtype)
+                image_features.append(image_feature)
+        else:
+            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
+            image_features = self.feature_select(image_forward_outs).to(images.dtype)
+
+        return image_features
+
+    @property
+    def dummy_feature(self):
+        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
+
+    @property
+    def dtype(self):
+        return self.vision_tower.dtype
+
+    @property
+    def device(self):
+        return self.vision_tower.device
+
+    @property
+    def config(self):
+        if self.is_loaded:
+            return self.vision_tower.config
+        else:
+            return self.cfg_only
+
+    @property
+    def hidden_size(self):
+        return self.config.hidden_size
+
+    @property
+    def num_patches_per_side(self):
+        return self.config.image_size // self.config.patch_size
+
+    @property
+    def num_patches(self):
+        return (self.config.image_size // self.config.patch_size) ** 2
diff --git a/scripts/v1_5/pretrain_siglip.sh b/scripts/v1_5/pretrain_siglip.sh
new file mode 100644
index 000000000..49183ce1a
--- /dev/null
+++ b/scripts/v1_5/pretrain_siglip.sh
@@ -0,0 +1,35 @@
+#!/bin/bash
+
+deepspeed llava/train/train_mem.py \
+    --deepspeed ./scripts/zero2.json \
+    --model_name_or_path lmsys/vicuna-13b-v1.5 \
+    --version plain \
+    --data_path ./playground/data/LLaVA-Pretrain/blip_laion_cc_sbu_558k.json \
+    --image_folder ./playground/data/LLaVA-Pretrain/images \
+    --vision_tower google/siglip-so400m-patch14-384 \
+    --mm_projector_type mlp2x_gelu \
+    --tune_mm_mlp_adapter True \
+    --mm_vision_select_layer -2 \
+    --mm_use_im_start_end False \
+    --mm_use_im_patch_token False \
+    --bf16 True \
+    --output_dir ./checkpoints/llava-v1.5-13b-pretrain-siglip \
+    --num_train_epochs 1 \
+    --per_device_train_batch_size 32 \
+    --per_device_eval_batch_size 4 \
+    --gradient_accumulation_steps 1 \
+    --evaluation_strategy "no" \
+    --save_strategy "steps" \
+    --save_steps 24000 \
+    --save_total_limit 1 \
+    --learning_rate 1e-3 \
+    --weight_decay 0. \
+    --warmup_ratio 0.03 \
+    --lr_scheduler_type "cosine" \
+    --logging_steps 1 \
+    --tf32 True \
+    --model_max_length 2048 \
+    --gradient_checkpointing True \
+    --dataloader_num_workers 4 \
+    --lazy_preprocess True \
+    --report_to wandb
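Patch 1/2 can be smoke-tested on its own. The snippet below is a minimal sketch, not part of the patch: the SimpleNamespace is a hypothetical stand-in for LLaVA's parsed training arguments and carries only the two fields SiglipVisionTower actually reads, and it assumes a transformers release that ships SiglipVisionModel plus network access to pull the checkpoint.

    import torch
    from types import SimpleNamespace

    from llava.model.multimodal_encoder.siglip_encoder import SiglipVisionTower

    # Hypothetical stand-in for the parsed training arguments; only these
    # two fields are read by SiglipVisionTower.__init__.
    args = SimpleNamespace(mm_vision_select_layer=-2, mm_vision_select_feature='patch')
    tower = SiglipVisionTower('google/siglip-so400m-patch14-384', args=args)

    # so400m-patch14-384: 384x384 input, 14x14 patches -> 27x27 = 729 tokens
    # with hidden size 1152. SigLIP has no [CLS] token, so feature_select
    # returns every token as a patch token.
    images = torch.randn(2, 3, 384, 384)
    features = tower(images)
    print(features.shape)  # expected: torch.Size([2, 729, 1152])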
From 1d5301e4221a153d98845c4689547f834d2a644f Mon Sep 17 00:00:00 2001
From: nahalam
Date: Tue, 17 Sep 2024 15:22:12 -0700
Subject: [PATCH 2/2] crop_size and interpolate_pos_encoding for SigLIP

---
 llava/train/train.py | 28 +++++++++++++++++++++++++++-
 1 file changed, 27 insertions(+), 1 deletion(-)

diff --git a/llava/train/train.py b/llava/train/train.py
index 477c668b6..aa4f11b64 100644
--- a/llava/train/train.py
+++ b/llava/train/train.py
@@ -37,6 +37,8 @@
 
 from PIL import Image
 
+import functools
+
 
 local_rank = None
 
@@ -46,6 +48,21 @@ def rank0_print(*args):
         print(*args)
 
 
+'''
+This wrapper forces interpolate_pos_encoding=True. If the input image size differs from the default SigLIP
+image size, the positional embeddings are interpolated to account for the new size.
+'''
+def wrap_siglip_forward_method(siglip_object):
+    original_forward = siglip_object.forward
+
+    @functools.wraps(original_forward)
+    def wrapped_forward(pixel_values, interpolate_pos_encoding=True):
+        return original_forward(pixel_values, interpolate_pos_encoding=True)  # always interpolate
+
+    siglip_object.forward = wrapped_forward
+    return siglip_object
+
+
 from packaging import version
 IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14')
 
@@ -734,7 +751,10 @@ def expand2square(pil_img, background_color):
             data_dict['image'] = image
         elif self.data_args.is_multimodal:
             # image does not exist in the data, but the model is multimodal
-            crop_size = self.data_args.image_processor.crop_size
+            if 'siglip' in self.data_args.image_processor.__class__.__name__.lower():
+                crop_size = self.data_args.image_processor.size
+            else:
+                crop_size = self.data_args.image_processor.crop_size
             data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
         return data_dict
 
@@ -916,6 +936,12 @@ def make_inputs_require_grad(module, input, output):
 
         vision_tower = model.get_vision_tower()
         vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)
+        if vision_tower.__class__.__name__ == 'SiglipVisionTower':
+            # Enforce interpolate_pos_encoding=True for the SigLIP embeddings
+            siglip_embedding = vision_tower.vision_tower.vision_model.embeddings
+            siglip_embedding = wrap_siglip_forward_method(siglip_embedding)
+            vision_tower.vision_tower.vision_model.embeddings = siglip_embedding
+
 
         data_args.image_processor = vision_tower.image_processor
         data_args.is_multimodal = True
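The effect of patch 2/2's wrapper can be checked on a bare SiglipVisionModel. This sketch repeats wrap_siglip_forward_method inline and rests on the same assumption the patch makes, namely that SiglipVisionEmbeddings.forward accepts an interpolate_pos_encoding argument (true of recent transformers releases):

    import functools

    import torch
    from transformers import SiglipVisionModel

    model = SiglipVisionModel.from_pretrained('google/siglip-so400m-patch14-384')
    embeddings = model.vision_model.embeddings

    # Same trick as wrap_siglip_forward_method above: interpolate the
    # positional embeddings no matter what the caller passes down.
    original_forward = embeddings.forward

    @functools.wraps(original_forward)
    def wrapped_forward(pixel_values, interpolate_pos_encoding=True):
        return original_forward(pixel_values, interpolate_pos_encoding=True)

    embeddings.forward = wrapped_forward

    # A 448x448 input gives 448 // 14 = 32 patches per side, i.e. 1024 tokens
    # instead of the pretrained 729; without interpolation, adding the
    # positional embeddings would fail on the size mismatch.
    pixel_values = torch.randn(1, 3, 448, 448)
    with torch.no_grad():
        out = model(pixel_values)
    print(out.last_hidden_state.shape)  # expected: torch.Size([1, 1024, 1152])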