
GRIT-VLP: GRouped mIni-baTch sampling for Efficient Vision-Language Pre-training

This is the official PyTorch implementation of "GRIT-VLP: GRouped mIni-baTch sampling for Efficient Vision-Language Pre-training" (accepted to ECCV 2022).

This repository contains the implementation code for both pre-training and fine-tuning GRIT-VLP.
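At a high level, GRIT composes each mini-batch from a group of mutually similar image-text pairs, so the in-batch negatives seen by the contrastive and matching objectives are harder than with uniform shuffling. The sketch below is only a conceptual illustration of grouped sampling, not the authors' sampler: the greedy nearest-neighbor chaining, the group_size parameter, and the precomputed embeddings are all simplifying assumptions.

import torch

def grouped_minibatch_indices(embeddings: torch.Tensor,
                              group_size: int,
                              generator: torch.Generator = None):
    """Toy grouped sampling: greedily chain each example to its most
    similar unused neighbor, split the chain into groups of `group_size`,
    then shuffle the groups.

    `embeddings` (N x D, assumed L2-normalized) stand in for the
    image-text features; in GRIT-VLP the similarities come from the
    model itself during pre-training.
    """
    n = embeddings.size(0)
    sims = embeddings @ embeddings.t()          # cosine similarity matrix
    sims.fill_diagonal_(float('-inf'))

    visited = torch.zeros(n, dtype=torch.bool)
    order = []
    cur = 0
    for _ in range(n):
        order.append(cur)
        visited[cur] = True
        sims[:, cur] = float('-inf')            # never revisit this example
        if visited.all():
            break
        cur = sims[cur].argmax().item()         # hop to most similar unused example

    # Groups are internally similar; shuffling the groups keeps
    # mini-batch composition varied across epochs.
    order = torch.tensor(order)
    groups = list(order.split(group_size))
    perm = torch.randperm(len(groups), generator=generator)
    return torch.cat([groups[i] for i in perm]).tolist()

A DataLoader could consume such an index list through a custom torch.utils.data.Sampler; the actual sampler used in this repository may differ in detail.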

Pre-training Dataset Download:

Downstream-task Datasets:

JSON Files:

  • Use the same JSON files as ALBEF
  • Change the image paths in the JSON files to match your downloaded images. For CC3M and SBU, some images can no longer be crawled, so account for these missing images when creating the JSON files (see the sketch below).
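A small script along these lines can drop annotation entries whose images were not crawled (a sketch assuming the ALBEF-style JSON layout, i.e. a list of dicts each with an 'image' path relative to an image root; the file names are illustrative):

import json
import os

def filter_missing_images(json_path: str, image_root: str, out_path: str):
    """Remove annotation entries whose image file was never downloaded."""
    with open(json_path) as f:
        annotations = json.load(f)

    # Keep only entries whose image actually exists on disk.
    kept = [a for a in annotations
            if os.path.exists(os.path.join(image_root, a['image']))]

    print(f'{json_path}: kept {len(kept)} / {len(annotations)} entries')
    with open(out_path, 'w') as f:
        json.dump(kept, f)

# Illustrative usage; adjust paths to your download locations.
filter_missing_images('cc3m_train.json', '/data/cc3m/images',
                      'cc3m_train_filtered.json')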

Requirements:

  • pytorch 1.8.0
  • transformers 4.8.1
  • timm 0.4.9

Pre-training:

  1. Pre-train the model using 4 A100 GPUs:
python3 -m torch.distributed.launch --nproc_per_node=4 --use_env Pretrain.py --config ./configs/Pretrain.yaml --output_dir output/Pretrain/  
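If you pre-train with a different number of GPUs, set --nproc_per_node accordingly and consider adjusting the per-GPU batch size in the config first. A minimal sketch of editing the YAML config (the batch_size key and the derived file name are assumptions, following the ALBEF-style configs this repository builds on):

import yaml

# Inspect/adjust the pre-training config before launching.
with open('./configs/Pretrain.yaml') as f:
    config = yaml.safe_load(f)

# 'batch_size' is assumed to be per GPU, as in the ALBEF configs;
# halve it if you hit out-of-memory errors on smaller GPUs.
config['batch_size'] = config.get('batch_size', 64) // 2

with open('./configs/Pretrain_small.yaml', 'w') as f:
    yaml.safe_dump(config, f)

Then launch with --config ./configs/Pretrain_small.yaml instead.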

Downstream tasks:

  1. IRTR (MS-COCO) using 4 A100 GPUs:
python3 -m torch.distributed.launch --nproc_per_node=4 --use_env Retrieval.py --config ./configs/Retrieval_coco.yaml --output_dir output/Retrieval_coco/  --checkpoint [Pretrained checkpoint] 
  2. IRTR (Flickr) using 4 A100 GPUs:
python3 -m torch.distributed.launch --nproc_per_node=4 --use_env Retrieval.py --config ./configs/Retrieval_flickr.yaml --output_dir output/Retrieval_flickr/ --checkpoint [Pretrained checkpoint] 
  3. NLVR using 4 A100 GPUs:
python3 -m torch.distributed.launch --nproc_per_node=4 --use_env Pretrain_nlvr.py --config ./configs/NLVR_pretrain.yaml --output_dir output/NLVR_pretrain/ --checkpoint [Pretrained checkpoint] 
python3 -m torch.distributed.launch --nproc_per_node=4 --use_env NLVR.py --config ./configs/NLVR.yaml --output_dir output/NLVR/ --checkpoint [NLVR-Pretrained checkpoint] 
  4. VQA using 4 A100 GPUs:
python3 -m torch.distributed.launch --nproc_per_node=4 --use_env VQA.py --config ./configs/VQA.yaml --output_dir output/vqa/ --checkpoint [Pretrained checkpoint] 
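The [Pretrained checkpoint] placeholders above take the path of a checkpoint written by Pretrain.py. To sanity-check a checkpoint before fine-tuning, something like the following works under the assumption of ALBEF-style checkpoints (a dict whose 'model' entry holds the state dict; the file name is illustrative):

import torch

ckpt = torch.load('output/Pretrain/checkpoint_29.pth', map_location='cpu')
print('top-level keys:', list(ckpt.keys()))

# ALBEF-style checkpoints keep the weights under 'model'; adjust if yours differ.
state_dict = ckpt['model'] if 'model' in ckpt else ckpt
print('num tensors:', len(state_dict))
for name in list(state_dict)[:5]:
    print(name, tuple(state_dict[name].shape))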

If you have any questions or problems running this code, please email [email protected] or [email protected]. Thank you!

Acknowledgement:

Our implementation is largely based on ALBEF, since our method builds on it. We thank the original authors for sharing their code.
