This is the official PyTorch implementation of "GRIT-VLP: GRouped mIni-baTch sampling for Efficient Vision-Language Pre-training" (Accepted to ECCV 2022)
This repository contains the implementation code for pre-training and fine-tuning GRIT-VLP.
- Use the same JSON files as ALBEF.
- Change the image paths in the JSON files to match your downloaded images. (In CC3M and SBU, some images can no longer be crawled, so account for these missing images when creating the JSON files; see the filtering sketch below.)
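One quick way to build such filtered JSON files is to drop entries whose image file is absent on disk. The sketch below is not part of this repo; the file names and the ALBEF-style annotation format (a JSON list of dicts with an "image" path relative to an image root) are assumptions.

```python
# Minimal filtering sketch (assumed ALBEF-style annotations).
import json
import os

image_root = "/path/to/cc3m_images"     # your download location (placeholder)
with open("cc3m_train.json") as f:      # hypothetical annotation file name
    annotations = json.load(f)

# Keep only entries whose image was actually crawled.
kept = [ann for ann in annotations
        if os.path.exists(os.path.join(image_root, ann["image"]))]
print(f"kept {len(kept)} / {len(annotations)} entries")

with open("cc3m_train_filtered.json", "w") as f:
    json.dump(kept, f)
```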
- pytorch 1.8.0
- transformers 4.8.1
- timm 0.4.9
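These can be installed with pip, pinning the versions above; pick the torch wheel that matches your CUDA setup, e.g.:
pip install torch==1.8.0 transformers==4.8.1 timm==0.4.9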
- Pre-train the model using 4 A100 GPUs:
python3 -m torch.distributed.launch --nproc_per_node=4 --use_env Pretrain.py --config ./configs/Pretrain.yaml --output_dir output/Pretrain/
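The fine-tuning commands below consume the resulting checkpoint via --checkpoint. As a minimal inspection sketch, assuming (as in ALBEF, which this code is based on) the checkpoint is a torch.save dict with the weights under a "model" key:

```python
import torch

# Path is a placeholder; point it at your saved checkpoint under output/Pretrain/.
ckpt = torch.load("output/Pretrain/checkpoint.pth", map_location="cpu")
# The "model" key is an assumption carried over from ALBEF-style checkpoints;
# fall back to the raw dict if the key is absent.
state_dict = ckpt.get("model", ckpt)
print(f"{len(state_dict)} parameter tensors")
```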
- IRTR (MS-COCO) using 4 A100 GPUs:
python3 -m torch.distributed.launch --nproc_per_node=4 --use_env Retrieval.py --config ./configs/Retrieval_coco.yaml --output_dir output/Retrieval_coco/ --checkpoint [Pretrained checkpoint]
- IRTR (Flickr) using 4 A100 GPUs:
python3 -m torch.distributed.launch --nproc_per_node=4 --use_env Retrieval.py --config ./configs/Retrieval_flickr.yaml --output_dir output/Retrieval_flickr/ --checkpoint [Pretrained checkpoint]
- NLVR using 4 A100 GPUs (an additional task-adaptation pre-training step precedes fine-tuning):
python3 -m torch.distributed.launch --nproc_per_node=4 --use_env Pretrain_nlvr.py --config ./configs/NLVR_pretrain.yaml --output_dir output/NLVR_pretrain/ --checkpoint [Pretrained checkpoint]
python3 -m torch.distributed.launch --nproc_per_node=4 --use_env NLVR.py --config ./configs/NLVR.yaml --output_dir output/NLVR/ --checkpoint [NLVR-Pretrained checkpoint]
- VQA using 4 A100 GPUs:
python3 -m torch.distributed.launch --nproc_per_node=4 --use_env VQA.py --config ./configs/VQA.yaml --output_dir output/vqa/ --checkpoint [Pretrained checkpoint]
If you have any questions or problems running this code, please email [email protected] or [email protected]. Thank you!
Our implementation is largely borrowed from ALBEF, since our method builds on it. We thank the original authors for sharing their code.