Chore: stage2-t5 2 input embedding
pej0918 committed Jul 22, 2024
1 parent 184aad1 commit 3d17986
Showing 2 changed files with 25 additions and 22 deletions.
41 changes: 22 additions & 19 deletions daiv/models/blip2_t5_instruct.py
@@ -120,7 +120,7 @@ def forward(self, samples):
image_atts_mcan = self.MCAN.make_mask(image_embeds_mcan).to(image.device)

text_input_mcan = samples["text_input"]
-text_input_llm = samples["text_input"]
+# text_input_llm = samples["text_input"]

# Process text for MCAN
text_tokens_mcan = self.tokenizer(
@@ -148,14 +148,16 @@ def forward(self, samples):
atts_llm_mcan = torch.ones(text_embeds_llm_mcan.size()[:-1], dtype=torch.long).to(image.device)

# Process text for LLM
-text_tokens_llm = self.tokenizer(
-    text_input_llm, return_tensors="pt", padding="longest", truncation=True, max_length=self.max_txt_len
-).input_ids.to(image.device)
-text_embeds_llm = self.text_embed_proj(self.MCAN.embedding(text_tokens_llm))
-atts_llm_text = torch.ones(text_embeds_llm.size()[:-1], dtype=torch.long).to(image.device)
+# text_tokens_llm = self.tokenizer(
+#     text_input_llm, return_tensors="pt", padding="longest", truncation=True, max_length=self.max_txt_len
+# ).input_ids.to(image.device)
+# text_embeds_llm = self.text_embed_proj(self.MCAN.embedding(text_tokens_llm))
+# atts_llm_text = torch.ones(text_embeds_llm.size()[:-1], dtype=torch.long).to(image.device)

-inputs_llm = torch.cat([image_embeds_llm, text_embeds_llm_mcan, text_embeds_llm], dim=1)
-atts_llm = torch.cat([image_atts_llm, atts_llm_mcan, atts_llm_text], dim=1)
+# inputs_llm = torch.cat([image_embeds_llm, text_embeds_llm_mcan, text_embeds_llm], dim=1)
+# atts_llm = torch.cat([image_atts_llm, atts_llm_mcan, atts_llm_text], dim=1)
+inputs_llm = torch.cat([image_embeds_llm, text_embeds_llm_mcan], dim=1)
+atts_llm = torch.cat([image_atts_llm, atts_llm_mcan], dim=1)

text_output = [t + self.t5_tokenizer.eos_token for t in samples["text_output"]]
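In effect, forward() now hands the T5 encoder just two streams, the projected image embeddings and the MCAN-fused question embeddings, with the standalone LLM text-embedding branch commented out. A minimal sketch of the resulting assembly, using dummy shapes (the dimensions below are illustrative assumptions, not the repo's actual sizes):

import torch

# Assumed batch size, image tokens, question tokens, and T5 hidden size.
B, N_img, N_txt, D = 2, 32, 20, 2048

image_embeds_llm = torch.randn(B, N_img, D)      # projected visual features
text_embeds_llm_mcan = torch.randn(B, N_txt, D)  # MCAN-fused question embeddings

image_atts_llm = torch.ones(image_embeds_llm.size()[:-1], dtype=torch.long)
atts_llm_mcan = torch.ones(text_embeds_llm_mcan.size()[:-1], dtype=torch.long)

# Only two streams are concatenated along the sequence dimension now.
inputs_llm = torch.cat([image_embeds_llm, text_embeds_llm_mcan], dim=1)
atts_llm = torch.cat([image_atts_llm, atts_llm_mcan], dim=1)

assert inputs_llm.shape == (B, N_img + N_txt, D)
assert atts_llm.shape == (B, N_img + N_txt)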

@@ -212,7 +214,7 @@ def generate(
image_atts_mcan = self.MCAN.make_mask(image_embeds_mcan).to(image.device)

text_input_mcan = samples["text_input"]
-text_input_llm = samples["text_input"]
+# text_input_llm = samples["text_input"]

# Process text for MCAN
text_tokens_mcan = self.tokenizer(
@@ -239,27 +241,28 @@ def generate(
atts_llm_mcan = torch.ones(text_embeds_llm_mcan.size()[:-1], dtype=torch.long).to(image.device)

# Process text for LLM
-text_tokens_llm = self.tokenizer(
-    text_input_llm, return_tensors="pt", padding="longest", truncation=True, max_length=self.max_txt_len
-).input_ids.to(image.device)
-text_embeds_llm = self.text_embed_proj(self.MCAN.embedding(text_tokens_llm))
-atts_llm_text = torch.ones(text_embeds_llm.size()[:-1], dtype=torch.long).to(image.device)
+# text_tokens_llm = self.tokenizer(
+#     text_input_llm, return_tensors="pt", padding="longest", truncation=True, max_length=self.max_txt_len
+# ).input_ids.to(image.device)
+# text_embeds_llm = self.text_embed_proj(self.MCAN.embedding(text_tokens_llm))
+# atts_llm_text = torch.ones(text_embeds_llm.size()[:-1], dtype=torch.long).to(image.device)

-inputs_llm = torch.cat([image_embeds_llm, text_embeds_llm, text_embeds_llm_mcan], dim=1)
-atts_llm = torch.cat([image_atts_llm, atts_llm_text, atts_llm_mcan], dim=1)
+# inputs_llm = torch.cat([image_embeds_llm, text_embeds_llm, text_embeds_llm_mcan], dim=1)
+# atts_llm = torch.cat([image_atts_llm, atts_llm_text, atts_llm_mcan], dim=1)
+inputs_llm = torch.cat([image_embeds_llm, text_embeds_llm_mcan], dim=1)
+atts_llm = torch.cat([image_atts_llm, atts_llm_mcan], dim=1)

if "prompt" in samples.keys():
prompt = samples["prompt"]
else:
prompt = self.prompt
# prompt = self.prompt
prompt = samples["text_input"]

# prompt = [prompt] * image.size(0)
if isinstance(prompt, str):
prompt = [prompt] * image.size(0)
else:
assert len(prompt) == image.size(0), "The number of prompts must be equal to the batch size."



t5_tokens = self.t5_tokenizer(
prompt,
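The practical effect in generate() is that, absent an explicit "prompt" key, the raw question text itself drives the T5 decoder. A standalone sketch of the selection-and-broadcast logic (the sample dict below is an illustrative assumption):

# Illustrative sample dict; keys follow the diff, contents are made up.
samples = {"text_input": ["What color is the car?", "How many dogs are there?"]}
batch_size = 2  # stands in for image.size(0)

if "prompt" in samples.keys():
    prompt = samples["prompt"]
else:
    # After this commit the question text is used instead of self.prompt.
    prompt = samples["text_input"]

if isinstance(prompt, str):
    prompt = [prompt] * batch_size
else:
    assert len(prompt) == batch_size, "The number of prompts must be equal to the batch size."

print(prompt)  # ['What color is the car?', 'How many dogs are there?']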
6 changes: 3 additions & 3 deletions train_configs/pretrain_stage2.yaml
@@ -4,11 +4,11 @@
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause

model:
-  arch: blip2_vicuna_instruct
-  model_type: vicuna7b
+  arch: blip2_t5_instruct
+  model_type: flant5xl
  load_pretrained: True
  # initialize stage 2 pretraining from stage 1 pretrained model
-  pretrained: "/root/workspace/24s-VQA-MLLM/EunJuPark/BLIVA/daiv/output/BLIP2/Pretrain_stage1/20240718133/checkpoint_9.pth" #"/home/intern24/daiv/BLIVA/daiv/output/BLIP2/Pretrain_stage1/mcan_stage1_checkpoint_9.pth"
+  pretrained: "/root/workspace/24s-VQA-MLLM/EunJuPark/mcan_stage1_checkpoint_9.pth" #"/home/intern24/daiv/BLIVA/daiv/output/BLIP2/Pretrain_stage1/mcan_stage1_checkpoint_9.pth"
  freeze_vit: True


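For reference, a hedged sketch of how a stage-1 checkpoint like the one named above might be restored (the "model" key and strict=False are assumptions about the checkpoint layout, not verified against the daiv loader):

import torch

# Path comes straight from the config above.
ckpt_path = "/root/workspace/24s-VQA-MLLM/EunJuPark/mcan_stage1_checkpoint_9.pth"

state = torch.load(ckpt_path, map_location="cpu")
state_dict = state.get("model", state)  # many LAVIS-style checkpoints nest weights under "model"
# model.load_state_dict(state_dict, strict=False)  # strict=False tolerates modules added in stage 2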
