Merge pull request mlpc-ucsd#4 from hanajibsa/stage2-t5-2
Chore: eval accuracy
hanajibsa authored Jul 22, 2024
2 parents 3d17986 + 5634f32 commit 465b0d2
Showing 8 changed files with 79 additions and 75 deletions.
8 changes: 4 additions & 4 deletions daiv/common/vqa_tools/vqa.py
@@ -189,9 +189,9 @@ def loadRes(self, resFile, quesFile):
anns = json.load(open(resFile))
assert type(anns) == list, "results is not an array of objects"
annsQuesIds = [ann["question_id"] for ann in anns]
assert set(annsQuesIds) == set(
self.getQuesIds()
), "Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file."
# assert set(annsQuesIds) == set(
# self.getQuesIds()
# ), "Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file."
for ann in anns:
quesId = ann["question_id"]
if res.dataset["task_type"] == "Multiple Choice":
@@ -208,4 +208,4 @@ def loadRes(self, resFile, quesFile):

res.dataset["annotations"] = anns
res.createIndex()
return res
return res, annsQuesIds
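With the set-equality assertion relaxed, a results file no longer has to answer every question id in the annotation split, and loadRes now also returns the ids it did find. As a minimal sketch, the payload it expects is a JSON array of {"question_id", "answer"} records; the ids, answers, and filename below are illustrative, not taken from the repository.

import json

# A partial results file: with the assert commented out, loadRes accepts a
# subset of the split's question ids and returns them alongside the VQA object.
predictions = [
    {"question_id": 1, "answer": "umbrella"},   # illustrative entries
    {"question_id": 7, "answer": "two"},
]
with open("okvqa_partial_result.json", "w") as f:   # placeholder filename
    json.dump(predictions, f)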
32 changes: 16 additions & 16 deletions daiv/configs/datasets/okvqa/defaults.yaml
@@ -54,21 +54,21 @@ datasets:
- /root/workspace/24s-VQA-MLLM/BEiT3/stage2-t5/VQA-MLLM-stage2/daiv/data/okvqa/okvqa_train.json
# - okvqa/annotations/OpenEnded_mscoco_train2014_questions.json
# - okvqa/annotations/mscoco_train2014_annotations.json
# test:
# url:
# # TODO make this order insensitive
# - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/okvqa_val_eval.json
# - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/okvqa_answer_list_train.json
# - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/OpenEnded_mscoco_val2014_questions.json
# - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/mscoco_val2014_annotations.json
# storage:
# # - okvqa/annotations/vqa_val_eval.json
# # - okvqa/annotations/answer_list.json
# # - okvqa/annotations/OpenEnded_mscoco_val2014_questions.json
# # - okvqa/annotations/mscoco_val2014_annotations.json
# - /root/workspace/24s-VQA-MLLM/BEiT3/stage2-eval/VQA-MLLM-stage2/daiv/data/okvqa/okvqa_val.json
# - /root/datasets/okvqa/data/assets/answer_dict_okvqa.json
# - /root/datasets/okvqa/data/okvqa/OpenEnded_mscoco_val2014_questions.json
# - /root/datasets/okvqa/data/okvqa/mscoco_val2014_annotations.json
test:
url:
# TODO make this order insensitive
- https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/okvqa_val_eval.json
- https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/okvqa_answer_list_train.json
- https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/OpenEnded_mscoco_val2014_questions.json
- https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/mscoco_val2014_annotations.json
storage:
# - okvqa/annotations/vqa_val_eval.json
# - okvqa/annotations/answer_list.json
# - okvqa/annotations/OpenEnded_mscoco_val2014_questions.json
# - okvqa/annotations/mscoco_val2014_annotations.json
- /root/workspace/24s-VQA-MLLM/BEiT3/stage2-t5/VQA-MLLM-stage2/daiv/data/okvqa/okvqa_val.json #testing
- /root/datasets/okvqa/data/assets/answer_dict_okvqa.json
- /root/datasets/okvqa/data/okvqa/OpenEnded_mscoco_val2014_questions.json
- /root/datasets/okvqa/data/okvqa/mscoco_val2014_annotations.json
images:
storage: /root/datasets/okvqa/data
2 changes: 1 addition & 1 deletion daiv/configs/models/blip2_instruct_flant5xl.yaml
@@ -8,7 +8,7 @@ model:
load_finetuned: False
#load_pretrained: True

pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/InstructBLIP/instruct_blip_flanxl_trimmed.pth"
# pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/InstructBLIP/instruct_blip_flanxl_trimmed.pth"
finetuned: ""

# vit encoder
1 change: 1 addition & 0 deletions daiv/data/okvqa/okvqa_val_one.json

Large diffs are not rendered by default.

42 changes: 21 additions & 21 deletions daiv/models/blip2_t5_instruct.py
@@ -120,7 +120,7 @@ def forward(self, samples):
image_atts_mcan = self.MCAN.make_mask(image_embeds_mcan).to(image.device)

text_input_mcan = samples["text_input"]
# text_input_llm = samples["text_input"]
text_input_llm = samples["text_input"]

# Process text for MCAN
text_tokens_mcan = self.tokenizer(
@@ -148,16 +148,16 @@ def forward(self, samples):
atts_llm_mcan = torch.ones(text_embeds_llm_mcan.size()[:-1], dtype=torch.long).to(image.device)

# Process text for LLM
# text_tokens_llm = self.tokenizer(
# text_input_llm, return_tensors="pt", padding="longest", truncation=True, max_length=self.max_txt_len
# ).input_ids.to(image.device)
# text_embeds_llm = self.text_embed_proj(self.MCAN.embedding(text_tokens_llm))
# atts_llm_text = torch.ones(text_embeds_llm.size()[:-1], dtype=torch.long).to(image.device)
text_tokens_llm = self.tokenizer(
text_input_llm, return_tensors="pt", padding="longest", truncation=True, max_length=self.max_txt_len
).input_ids.to(image.device)
text_embeds_llm = self.text_embed_proj(self.MCAN.embedding(text_tokens_llm))
atts_llm_text = torch.ones(text_embeds_llm.size()[:-1], dtype=torch.long).to(image.device)

# inputs_llm = torch.cat([image_embeds_llm, text_embeds_llm_mcan, text_embeds_llm], dim=1)
# atts_llm = torch.cat([image_atts_llm, atts_llm_mcan, atts_llm_text], dim=1)
inputs_llm = torch.cat([image_embeds_llm, text_embeds_llm_mcan], dim=1)
atts_llm = torch.cat([image_atts_llm, atts_llm_mcan], dim=1)
inputs_llm = torch.cat([image_embeds_llm, text_embeds_llm_mcan, text_embeds_llm], dim=1)
atts_llm = torch.cat([image_atts_llm, atts_llm_mcan, atts_llm_text], dim=1)
# inputs_llm = torch.cat([image_embeds_llm, text_embeds_llm_mcan], dim=1)
# atts_llm = torch.cat([image_atts_llm, atts_llm_mcan], dim=1)

text_output = [t + self.t5_tokenizer.eos_token for t in samples["text_output"]]

@@ -214,7 +214,7 @@ def generate(
image_atts_mcan = self.MCAN.make_mask(image_embeds_mcan).to(image.device)

text_input_mcan = samples["text_input"]
# text_input_llm = samples["text_input"]
text_input_llm = samples["text_input"]

# Process text for MCAN
text_tokens_mcan = self.tokenizer(
@@ -241,16 +241,16 @@ def generate(
atts_llm_mcan = torch.ones(text_embeds_llm_mcan.size()[:-1], dtype=torch.long).to(image.device)

# Process text for LLM
# text_tokens_llm = self.tokenizer(
# text_input_llm, return_tensors="pt", padding="longest", truncation=True, max_length=self.max_txt_len
# ).input_ids.to(image.device)
# text_embeds_llm = self.text_embed_proj(self.MCAN.embedding(text_tokens_llm))
# atts_llm_text = torch.ones(text_embeds_llm.size()[:-1], dtype=torch.long).to(image.device)

# inputs_llm = torch.cat([image_embeds_llm, text_embeds_llm, text_embeds_llm_mcan], dim=1)
# atts_llm = torch.cat([image_atts_llm, atts_llm_text, atts_llm_mcan], dim=1)
inputs_llm = torch.cat([image_embeds_llm, text_embeds_llm_mcan], dim=1)
atts_llm = torch.cat([image_atts_llm, atts_llm_mcan], dim=1)
text_tokens_llm = self.tokenizer(
text_input_llm, return_tensors="pt", padding="longest", truncation=True, max_length=self.max_txt_len
).input_ids.to(image.device)
text_embeds_llm = self.text_embed_proj(self.MCAN.embedding(text_tokens_llm))
atts_llm_text = torch.ones(text_embeds_llm.size()[:-1], dtype=torch.long).to(image.device)

inputs_llm = torch.cat([image_embeds_llm, text_embeds_llm_mcan, text_embeds_llm], dim=1)
atts_llm = torch.cat([image_atts_llm, atts_llm_mcan, atts_llm_text], dim=1)
# inputs_llm = torch.cat([image_embeds_llm, text_embeds_llm_mcan], dim=1)
# atts_llm = torch.cat([image_atts_llm, atts_llm_mcan], dim=1)

if "prompt" in samples.keys():
prompt = samples["prompt"]
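For context, the re-enabled branch in generate (and in forward above) feeds the T5 encoder a concatenation of projected image tokens, MCAN-fused question features, and the plain question embeddings, instead of dropping the last block. A shape-only sketch follows; the batch size, token counts, and hidden width are assumptions, not values from the config.

import torch

B, D = 4, 2048                                    # assumed batch size and LLM input width
image_embeds_llm = torch.randn(B, 32, D)          # projected image tokens
text_embeds_llm_mcan = torch.randn(B, 14, D)      # MCAN-fused question features
text_embeds_llm = torch.randn(B, 14, D)           # plain question embeddings (re-enabled here)

inputs_llm = torch.cat([image_embeds_llm, text_embeds_llm_mcan, text_embeds_llm], dim=1)
atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long)
print(inputs_llm.shape, atts_llm.shape)           # torch.Size([4, 60, 2048]) torch.Size([4, 60])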
42 changes: 21 additions & 21 deletions daiv/runners/runner_base.py
@@ -384,27 +384,27 @@ def train(self):
self.log_stats(split_name="train", stats=train_stats)

# evaluation phase
if len(self.valid_splits) > 0:
for split_name in self.valid_splits:
logging.info("Evaluating on {}.".format(split_name))

val_log = self.eval_epoch(
split_name=split_name, cur_epoch=cur_epoch
)
if val_log is not None:
if is_main_process():
assert (
"agg_metrics" in val_log
), "No agg_metrics found in validation log."

agg_metrics = val_log["agg_metrics"]
if agg_metrics > best_agg_metric and split_name == "val":
best_epoch, best_agg_metric = cur_epoch, agg_metrics

self._save_checkpoint(cur_epoch, is_best=True)

val_log.update({"best_epoch": best_epoch})
self.log_stats(val_log, split_name)
# if len(self.valid_splits) > 0:
# for split_name in self.valid_splits:
# logging.info("Evaluating on {}.".format(split_name))

# val_log = self.eval_epoch(
# split_name=split_name, cur_epoch=cur_epoch
# )
# if val_log is not None:
# if is_main_process():
# assert (
# "agg_metrics" in val_log
# ), "No agg_metrics found in validation log."

# agg_metrics = val_log["agg_metrics"]
# if agg_metrics > best_agg_metric and split_name == "val":
# best_epoch, best_agg_metric = cur_epoch, agg_metrics

# self._save_checkpoint(cur_epoch, is_best=True)

# val_log.update({"best_epoch": best_epoch})
# self.log_stats(val_log, split_name)

else:
# if no validation split is provided, we just save the checkpoint at the end of each epoch.
22 changes: 12 additions & 10 deletions daiv/tasks/vqa.py
@@ -33,7 +33,8 @@ def __init__(
sample_id_key = "",
ques_files=dict(),
anno_files=dict(),
valid_splits=['val']
# valid_splits=['val']
valid_splits=['test']
):
super().__init__()

@@ -75,7 +76,8 @@ def setup_task(cls, cfg):
sample_id_key = run_cfg.get("sample_id_key", "instance_id")
ques_files = run_cfg.get("ques_files", dict())
anno_files = run_cfg.get("anno_files", dict())
valid_splits = run_cfg.get("valid_splits", ["val"])
# valid_splits = run_cfg.get("valid_splits", ["val"])
valid_splits = run_cfg.get("valid_splits", ["test"])


return cls(
@@ -99,7 +101,7 @@ def build_datasets(self, cfg):
for split in self.valid_splits:
if split not in dataset:
print(f"Split {split} not found in {ds_name}.")
continue # added
# continue # added
if (
hasattr(dataset[split], "coco_fmt_qust_file")
and dataset[split].coco_fmt_qust_file is not None
@@ -141,14 +143,14 @@ def valid_step(self, model, samples):
prompt=self.prompt,
)
pred_qa_pairs = []

question = samples['text_input']
question_id = samples["question_id"]
for answer, ques_id in zip(answers, question_id):
for answer, ques_id, que in zip(answers, question_id, question):
ques_id = int(ques_id.item()) if isinstance(ques_id, torch.Tensor) else ques_id
if ques_id != int and is_convertible_to_int(ques_id):
ques_id = int(ques_id)
pred_qa_pairs.append({"question_id": ques_id, "answer": answer})
print(f'answer : {answer}')
print(f'question: {que} / answer : {answer}')

return pred_qa_pairs

@@ -161,7 +163,7 @@ def after_evaluation(self, val_result, split_name, **kwargs):
)

metrics = self._report_metrics(result_file=result_file, split=split_name)

print(metrics)
return metrics

@dist_utils.main_process
@@ -170,17 +172,17 @@ def _report_metrics(self, result_file, split):
Use official VQA evaluation script to report metrics.
"""
metrics = {}

print(f'ques_files: {self.ques_files} / anno_files: {self.anno_files}' )
if split in self.ques_files and split in self.anno_files:
vqa = VQA(self.anno_files[split], self.ques_files[split])
vqa_result = vqa.loadRes(
vqa_result, resQuesIds = vqa.loadRes(
resFile=result_file, quesFile=self.ques_files[split]
)
# create vqaEval object by taking vqa and vqaRes
# n is precision of accuracy (number of places after decimal), default is 2
vqa_scorer = VQAEval(vqa, vqa_result, n=2)
logging.info("Start VQA evaluation.")
vqa_scorer.evaluate()
vqa_scorer.evaluate(resQuesIds)

# print accuracies
overall_acc = vqa_scorer.accuracy["overall"]
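Taken together with the vqa.py change, _report_metrics now scores only the question ids present in the results file. Below is a rough end-to-end sketch of that flow; the file paths are placeholders and the VQAEval import path is a guess based on the vqa_tools layout, so treat it as an assumption rather than the repository's exact wiring.

from daiv.common.vqa_tools.vqa import VQA
from daiv.common.vqa_tools.vqa_eval import VQAEval  # import path assumed

anno_file = "mscoco_val2014_annotations.json"            # placeholder paths
ques_file = "OpenEnded_mscoco_val2014_questions.json"
result_file = "okvqa_val_result.json"

vqa = VQA(anno_file, ques_file)
vqa_result, res_ques_ids = vqa.loadRes(resFile=result_file, quesFile=ques_file)

vqa_scorer = VQAEval(vqa, vqa_result, n=2)   # n = decimal places of reported accuracy
vqa_scorer.evaluate(res_ques_ids)            # score only the answered question ids
print(vqa_scorer.accuracy["overall"])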
5 changes: 3 additions & 2 deletions train_configs/pretrain_stage2_eval.yaml
@@ -67,11 +67,12 @@ run:

amp: True
# MCAN
resume_ckpt_path: '/root/workspace/24s-VQA-MLLM/EunJuPark/VQA-MLLM-stage2/daiv/output/BLIP2/Pretrain_stage2_eval/20240720153/checkpoint_9.pth'
resume_ckpt_path: '/root/workspace/24s-VQA-MLLM/EunJuPark/stage2/BLIVA/daiv/output/BLIP2/Pretrain_stage2/20240720153/checkpoint_9.pth'
# resume_ckpt_path: '/root/workspace/24s-VQA-MLLM/EunJuPark/VQA-MLLM-stage2/daiv/output/BLIP2/Pretrain_stage2_eval/20240720153/checkpoint_9.pth'
# resume_ckpt_path: '/root/workspace/24s-VQA-MLLM/BEiT3/VQA-MLLM-stage2/daiv/output/BLIP2/Pretrain_stage2/20240719160/checkpoint_9.pth'

evaluate: True
train_splits: ["train"]
# train_splits: ["train"]
test_splits: ["test"]

device: "cuda"
