Update modeling_otter.py
Luodian authored Dec 23, 2023
1 parent 50cacb9 commit e5c67b2
Showing 1 changed file with 1 addition and 285 deletions.
286 changes: 1 addition & 285 deletions src/otter_ai/models/otter/modeling_otter.py
@@ -1,21 +1,16 @@
import builtins
import random
import sys
from typing import List, Optional

import torch
import torch.distributed as dist
import torch.nn as nn
from accelerate import Accelerator
from accelerate.hooks import AlignDevicesHook, add_hook_to_module
from einops import rearrange, repeat
from peft import LoraConfig, TaskType, get_peft_model
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.models.auto import AutoModel, AutoModelForCausalLM, AutoTokenizer

from pipeline.utils.modeling_value_head import AutoModelForCausalLMWithValueHead

from transformers.models.auto import AutoTokenizer
from ..falcon.modelling_RW import RWForCausalLM
from ..mpt.modeling_mpt import MPTForCausalLM
from ..mpt_redpajama.mosaic_gpt import MosaicGPT
@@ -1045,282 +1040,3 @@ def generate(

self.lang_encoder.clear_conditioned_layers()
return output


class OtterForConditionalGenerationWithValueHead(OtterPreTrainedModel):
config_class = OtterConfig

def __init__(
self,
config: OtterConfig,
):
super().__init__(config)
        ### TODO: set "LlamaForCausalLM" as text_config.architectures[0] for LLaMA-based Flamingo checkpoints
if "llama" not in config.text_config._name_or_path:
if config.text_config.architectures[0] == "MPTForCausalLM":
text_tokenizer = AutoTokenizer.from_pretrained("mosaicml/mpt-7b-instruct")
lang_encoder = MPTForCausalLM(config=config.text_config)
elif config.text_config.architectures[0] == "MosaicGPT":
text_tokenizer = AutoTokenizer.from_pretrained("mosaicml/mosaic-llama-redpajama-final-candidate")
lang_encoder = MosaicGPT(config=config.text_config)
elif config.text_config.architectures[0] == "RWForCausalLM":
text_tokenizer = AutoTokenizer.from_pretrained("PATH-TO-YOUR-FALCON")
lang_encoder = RWForCausalLM(config=config.text_config)
elif config.text_config.architectures[0] == "LlamaForCausalLM":
text_tokenizer = LlamaTokenizer.from_pretrained(config.text_config._name_or_path)
lang_encoder = LlamaForCausalLM(config=config.text_config)
else:
text_tokenizer = LlamaTokenizer.from_pretrained(config.text_config._name_or_path)
lang_encoder = LlamaForCausalLM(config=config.text_config)
vision_encoder = CLIPVisionModel(config=config.vision_config)

text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>", "<answer>"]})
if text_tokenizer.pad_token is None:
text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
self.text_tokenizer = text_tokenizer
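        # encode() may prepend BOS or other tokens, so the id of each added special token is taken from the last position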
self.eoc_token_id = text_tokenizer.encode("<|endofchunk|>")[-1]
self.media_token_id = text_tokenizer.encode("<image>")[-1]
self.lang_encoder_with_vhead = AutoModelForCausalLMWithValueHead(lang_encoder)
extend_instance(self.lang_encoder_with_vhead.pretrained_model, OtterLMMixin)
decoder_layers_attr_name = _infer_decoder_layers_attr_name(self.lang_encoder_with_vhead.pretrained_model)
self.lang_encoder_with_vhead.pretrained_model.set_decoder_layers_attr_name(decoder_layers_attr_name)
if self.lang_encoder_with_vhead.pretrained_model.__class__.__name__ == "LlamaForCausalLM":
self.lang_encoder_with_vhead.pretrained_model.resize_token_embeddings(len(text_tokenizer))

self.cross_attn_every_n_layers = config.cross_attn_every_n_layers
# use_media_placement_augmentation is strictly false for Otter model
self.use_media_placement_augmentation = False # config.use_media_placement_augmentation
self.max_num_frames = config.max_num_frames if hasattr(config, "max_num_frames") else None

# Informative master_print statement
if self.max_num_frames is None or self.max_num_frames == 1:
master_print(f"The current model version is configured for Otter-Image with max_num_frames set to {self.max_num_frames}.")
else:
master_print(f"The current model version is configured for Otter-Video with a maximum of {self.max_num_frames} frames.")

vision_encoder.output_tokens = True
self.vision_encoder = vision_encoder

self.vis_dim = 1024
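        # 1024 matches the hidden size of the CLIP ViT-L vision tower whose patch tokens feed the perceiver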
self.perceiver = OtterPerceiverResampler(dim=self.vis_dim, max_num_frames=self.max_num_frames)

self.lang_encoder_with_vhead.pretrained_model.init_otter(
media_token_id=self.media_token_id,
vis_hidden_size=self.vis_dim,
cross_attn_every_n_layers=self.cross_attn_every_n_layers,
use_media_placement_augmentation=self.use_media_placement_augmentation,
)

if "lora_config" in config.__dict__:
original_architecture_name = self.lang_encoder_with_vhead.pretrained_model.__class__.__name__
master_print(f"Using LoRA with config:{config.lora_config}")
standard_modules = ["q_proj", "v_proj"]
lang_encoder_short_name = MODEL_CLASSES[config.text_config.architectures[0]]
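            # target_modules below must name the attention projection modules of the chosen architecture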
model_to_lora_modules = {
"llama": standard_modules,
"opt": standard_modules,
"gptj": standard_modules,
"gpt_neox": ["query_key_value"],
"mpt": ["Wqkv"],
}
lora_config = LoraConfig(
r=config.lora_config["r"],
lora_alpha=config.lora_config["lora_alpha"],
lora_dropout=config.lora_config["lora_dropout"],
task_type=TaskType.CAUSAL_LM,
target_modules=model_to_lora_modules[lang_encoder_short_name],
)
self.lang_encoder_with_vhead.pretrained_model = get_peft_model(self.lang_encoder_with_vhead.pretrained_model, lora_config)
            self.lang_encoder_with_vhead.pretrained_model.print_trainable_parameters()
self.lang_encoder_with_vhead.pretrained_model.__class__.__name__ = f"{original_architecture_name}LoRA"

self.post_init()

    # The language model sits inside the value-head wrapper, so accessors are routed through .pretrained_model.
    def get_input_embeddings(self) -> nn.Module:
        return self.lang_encoder_with_vhead.pretrained_model.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        self.lang_encoder_with_vhead.pretrained_model.set_input_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        return self.lang_encoder_with_vhead.pretrained_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.lang_encoder_with_vhead.pretrained_model.set_output_embeddings(new_embeddings)

    def get_image_encoder(self) -> nn.Module:
        return self.vision_encoder

    def get_lang_encoder(self) -> nn.Module:
        return self.lang_encoder_with_vhead.pretrained_model

def init_weights(self):
        # Freeze all parameters by default
for param in self.parameters():
param.requires_grad = False

        # Unfreeze the vision encoder if requested
if "train_vision_encoder" in self.config.__dict__ and self.config.train_vision_encoder is True:
for param in self.vision_encoder.parameters():
param.requires_grad = True

        # Unfreeze the language encoder (including the value head) if requested
if "train_lang_encoder" in self.config.__dict__ and self.config.train_lang_encoder is True:
for name, param in self.lang_encoder_with_vhead.named_parameters():
param.requires_grad = True

        # Unfreeze the connector (gated cross-attention layers and the perceiver) if requested
        if "train_connector" in self.config.__dict__ and self.config.train_connector is True:
            for name, param in self.lang_encoder_with_vhead.pretrained_model.named_parameters():
if "gated_cross_attn_layer" in name:
param.requires_grad = True
for name, param in self.named_parameters():
if "perceiver" in name:
param.requires_grad = True

if "lora_config" in self.config.__dict__:
            # With LoRA, trainable parameters are already selected by the PEFT config; report their count here
master_print(f"LoRA trainable param: {(sum(p.numel() for p in self.lang_encoder_with_vhead.pretrained_model.parameters() if p.requires_grad)) / 1e9:.3f} B")

# Unfreeze LM input and output embeddings
self.lang_encoder_with_vhead.pretrained_model.get_input_embeddings().requires_grad_(True)
        ## MPTForCausalLM ties its word embeddings, so only unfreeze lm_head for the LLaMA case
        if "LlamaForCausalLM" in self.lang_encoder_with_vhead.pretrained_model.__class__.__name__:
            self.lang_encoder_with_vhead.pretrained_model.lm_head.requires_grad_(True)
# master_print("====================Model Grad Part====================")
total_params = 0
for name, param in self.named_parameters():
if param.requires_grad:
total_params += param.numel()
master_print(f"Parameter: {name}, Size: {param.numel() / 1e6:.6f} M")
master_print(f"Total Trainable param: {total_params / 1e9:.6f} B")

def forward(
self,
vision_x: torch.Tensor,
lang_x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cached_vision_x: bool = False,
clear_conditioned_layers: bool = True,
past_key_values: Optional[List[torch.FloatTensor]] = None,
use_cache: bool = False,
**kwargs,
) -> CausalLMOutputWithPast:
"""
Forward pass of Otter.
Args:
vision_x (torch.Tensor): Vision input
shape (B, T_img, F, C, H, W) with F=1
lang_x (torch.Tensor): Language input ids
shape (B, T_txt)
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
labels (torch.Tensor, optional): Labels. Defaults to None.
clear_conditioned_layers: if True, clear the conditioned layers
                once the forward pass is completed. Set this to False if the
same set of images will be reused in another subsequent
forward pass.
past_key_values: pre-computed values to pass to language model.
See past_key_values documentation in Hugging Face
CausalLM models.
use_cache: whether to use cached key values. See use_cache
documentation in Hugging Face CausalLM models.
"""
        assert (vision_x is not None) or use_cached_vision_x, "Must provide vision_x or set use_cached_vision_x to True."

if use_cached_vision_x:
# Case: use cached; vision_x should be cached and other
# vision-related inputs should not be provided.
assert vision_x is None, "Expect vision_x to be None when use_cached_vision_x is True."
assert self.lang_encoder_with_vhead.is_conditioned()

else:
# Case: do not use caching (i.e. this is a standard forward pass);
self._encode_vision_x(vision_x=vision_x)

output = self.lang_encoder_with_vhead(
input_ids=lang_x,
attention_mask=attention_mask,
labels=labels,
past_key_values=past_key_values,
use_cache=use_cache,
**kwargs,
)

if clear_conditioned_layers:
self.lang_encoder_with_vhead.clear_conditioned_layers()

return output

def _encode_vision_x(self, vision_x: torch.Tensor):
"""
Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
Args:
vision_x (torch.Tensor): Vision input
shape (B, T_img, F, C, H, W)
Images in the same chunk are collated along T_img, and frames are collated along F
Currently only F=1 is supported (single-frame videos)
rearrange code based on https://github.com/dhansmair/flamingo-mini
"""

assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
b, T, F = vision_x.shape[:3]

vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
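        # CLIPVisionModel returns (last_hidden_state, pooled_output, ...); keep the patch tokens and drop the leading CLS token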
vision_x = self.vision_encoder(vision_x)[0][:, 1:, :]
vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)

vision_x = self.perceiver(vision_x) # reshapes to (b, T, n, d)

        for layer in self.lang_encoder_with_vhead.pretrained_model._get_decoder_layers():
layer.condition_vis_x(vision_x)

@torch.no_grad()
def generate(
self,
vision_x: torch.Tensor,
lang_x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
**generate_kwargs,
):
"""
Generate text conditioned on vision and language inputs.
Args:
vision_x (torch.Tensor): Vision input
shape (B, T_img, F, C, H, W)
images in the same chunk are collated along T_img, and frames are collated along F
currently only F=1 is supported (single-frame videos)
lang_x (torch.Tensor): Language input
shape (B, T_txt)
max_length (int, optional): Maximum length of the output. Defaults to None.
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
Returns:
torch.Tensor: lang_x with generated tokens appended to it
"""
if hasattr(self, "_hf_hook"):
# add a hook to make sure that the output of lang_encoder is mapped to the same device as the lang_x
hook = AlignDevicesHook(
execution_device=lang_x.device,
io_same_device=True,
place_submodules=False,
)
            add_hook_to_module(self.lang_encoder_with_vhead.pretrained_model, hook)
num_beams = generate_kwargs.get("num_beams", 1)
if num_beams > 1:
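            # beam search expands the batch dimension, so repeat the vision features num_beams times to stay aligned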
vision_x = vision_x.repeat_interleave(num_beams, dim=0)
self._encode_vision_x(vision_x=vision_x)
        output = self.lang_encoder_with_vhead.pretrained_model.generate(
            input_ids=lang_x,
            attention_mask=attention_mask,
            eos_token_id=self.eoc_token_id,
            **generate_kwargs,
        )

        self.lang_encoder_with_vhead.pretrained_model.clear_conditioned_layers()
return output
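
A minimal usage sketch for the removed OtterForConditionalGenerationWithValueHead follows. It assumes a local checkpoint directory holding an OtterConfig plus weights and that the module is importable as otter_ai.models.otter.modeling_otter; the checkpoint path, placeholder image, and prompt format are illustrative only.

# Hypothetical usage sketch; the path, image, and prompt below are placeholders, not part of the original file.
import torch
from PIL import Image
from transformers import CLIPImageProcessor

from otter_ai.models.otter.modeling_otter import OtterForConditionalGenerationWithValueHead

model = OtterForConditionalGenerationWithValueHead.from_pretrained("/path/to/otter-checkpoint")
tokenizer = model.text_tokenizer

# vision_x has shape (B, T_img, F, C, H, W); a single image means T_img = F = 1.
image = Image.new("RGB", (224, 224))  # placeholder image
pixel_values = CLIPImageProcessor()(images=[image], return_tensors="pt")["pixel_values"]
vision_x = pixel_values.unsqueeze(1).unsqueeze(1)

encoded = tokenizer(["<image>User: What is in this picture? GPT:<answer>"], return_tensors="pt")
output_ids = model.generate(
    vision_x=vision_x,
    lang_x=encoded["input_ids"],
    attention_mask=encoded["attention_mask"],
    max_new_tokens=64,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))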
