Please download the dataset from Hugging Face Datasets and put it in the ./datasets/nextqa directory by running:
python examples/download_data_hf.py --hf_root rhymes-ai/NeXTVideo --save_root ./datasets/nextqa
Then unzip the .zip files (including images and videos) inside each sub-folder:
cd ./datasets/nextqa
unzip NExTVideo.zip
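The extraction step can also be scripted, which helps when there are many sub-folders. The following is a minimal sketch using only the Python standard library; the helper name unzip_all is hypothetical, and it assumes each archive should be extracted into the folder that contains it, which is what running unzip inside each sub-folder does.

```python
import zipfile
from pathlib import Path


def unzip_all(root: str) -> list:
    """Extract every .zip archive found under `root` (recursively) into the
    folder that contains it. Returns the list of archives that were extracted."""
    extracted = []
    # sorted() materializes the list first, so archives produced by the
    # extraction itself are not revisited during this call.
    for archive in sorted(Path(root).rglob("*.zip")):
        with zipfile.ZipFile(archive) as zf:
            # Mirror `cd <sub-folder> && unzip <file>.zip`: extract next to the archive.
            zf.extractall(archive.parent)
        extracted.append(archive)
    return extracted
```

For example, unzip_all("./datasets/nextqa") extracts every archive under the dataset root and returns their paths.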
The LoRA training configuration is in config_lora.yaml. Please change the paths to the Aria model, the Aria tokenizer, and the NextQA dataset to your own. Because videos produce a longer visual context, this setting uses a 4k sequence length and runs on a single 80GB A100. We set max_image_size to 490 for video datasets.
Note: In this configuration, LoRA is added to all modules of Aria's LLM, but not to the ViT or the projector. To add LoRA to the ViT or the projector, adjust freeze_vit or freeze_projector. You can also adjust lora_target_modules to choose which sub-modules of the LLM blocks receive LoRA, and freeze_llm_layers to set the layers where you don't want LoRA added.
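One way to picture how these four options combine is a filter over module names. The sketch below is illustrative only: it assumes a naming scheme (vision_tower.*, multi_modal_projector.*, language_model.*.layers.<i>.*) that may not match Aria's actual module names, the helper select_lora_modules is hypothetical, and it is not how aria/train.py is implemented.

```python
import re

# ASSUMED naming scheme (hypothetical, may differ from Aria's real module names):
#   vision_tower.*                -> ViT
#   multi_modal_projector.*       -> projector
#   language_model.*layers.<i>.*  -> LLM decoder block i
LAYER_RE = re.compile(r"\.layers\.(\d+)\.")


def select_lora_modules(module_names, lora_target_modules, freeze_llm_layers=(),
                        freeze_vit=True, freeze_projector=True):
    """Return the module names that would receive LoRA under the assumed scheme."""
    frozen_layers = set(freeze_llm_layers)
    targets = set(lora_target_modules)
    selected = []
    for name in module_names:
        if name.startswith("vision_tower."):
            if freeze_vit:  # ViT is skipped unless freeze_vit is False
                continue
        elif name.startswith("multi_modal_projector."):
            if freeze_projector:  # projector is skipped unless freeze_projector is False
                continue
        elif name.startswith("language_model."):
            layer = LAYER_RE.search(name)
            if layer and int(layer.group(1)) in frozen_layers:
                continue  # this LLM block is listed in freeze_llm_layers
        else:
            continue  # anything else is never adapted in this sketch
        # Match on the last name component, e.g. "q_proj" in "...self_attn.q_proj".
        if name.rsplit(".", 1)[-1] in targets:
            selected.append(name)
    return selected
```

For comparison, PEFT's LoraConfig exposes similar knobs through its target_modules and layers_to_transform parameters; whether aria/train.py maps its options onto them is not stated here.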
Command (on a single 80GB A100):
CUDA_VISIBLE_DEVICES=0 python aria/train.py --config examples/nextqa/config_lora.yaml --output_dir [YOUR_OUT_DIR]
Full-parameter fine-tuning is feasible on 8 H100 GPUs using DeepSpeed ZeRO-3 with parameter offloading. The command is as follows:
accelerate launch --config_file recipes/accelerate_configs/zero3_offload.yaml aria/train.py --config examples/nextqa/config_full.yaml --output_dir [YOUR_OUT_DIR]
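To see why full-parameter fine-tuning needs ZeRO-3 and offloading, it helps to estimate per-GPU memory for model states. This sketch follows the ZeRO paper's accounting for mixed-precision Adam (2 bytes of half-precision weights, 2 bytes of gradients, and 12 bytes of fp32 optimizer state per parameter) and ignores activations, which grow with sequence length. The function name is hypothetical, and DeepSpeed's real footprint also includes buffers and fragmentation.

```python
def zero_model_state_gb(num_params: float, num_gpus: int, stage: int) -> float:
    """Approximate per-GPU memory (GB) for model states under ZeRO stages 0-3.

    Per parameter: 2 B half-precision weights, 2 B gradients, 12 B optimizer
    states (fp32 master weights, momentum, variance). Activations are ignored.
    """
    params, grads, optim = 2 * num_params, 2 * num_params, 12 * num_params
    if stage == 0:    # plain data parallelism: everything replicated
        total = params + grads + optim
    elif stage == 1:  # optimizer states partitioned
        total = params + grads + optim / num_gpus
    elif stage == 2:  # gradients and optimizer states partitioned
        total = params + (grads + optim) / num_gpus
    elif stage == 3:  # weights, gradients and optimizer states partitioned
        total = (params + grads + optim) / num_gpus
    else:
        raise ValueError("stage must be 0, 1, 2 or 3")
    return total / 1e9
```

With stage 3 on 8 GPUs, each GPU holds 16Ψ/8 = 2Ψ bytes of model states. A model with roughly 25B total parameters (the order of Aria's MoE total; check the model card for the exact count) would therefore need about 50 GB per GPU before any activations, which is why offloading parameters to CPU memory helps on 80GB cards.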
Note: If you train all parameters with DeepSpeed ZeRO, you need to extract the consolidated fp32 weights from the ZeRO stage 1, 2, or 3 checkpoints:
cd /path/to/your/output/dir
python zero_to_fp32.py . pytorch_model.bin
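The same consolidation can be done from Python with DeepSpeed's get_fp32_state_dict_from_zero_checkpoint, which lives in deepspeed.utils.zero_to_fp32. A minimal sketch, assuming the output directory is a DeepSpeed checkpoint directory containing the `latest` tag file that zero_to_fp32.py reads by default; the helper names read_latest_tag and consolidate are hypothetical.

```python
from pathlib import Path


def read_latest_tag(checkpoint_dir: str) -> str:
    """Return the tag (e.g. "global_step500") stored in DeepSpeed's `latest` file."""
    latest = Path(checkpoint_dir) / "latest"
    if not latest.is_file():
        raise FileNotFoundError(f"no 'latest' file in {checkpoint_dir}")
    return latest.read_text().strip()


def consolidate(checkpoint_dir: str, output_file: str = "pytorch_model.bin") -> None:
    """Merge the ZeRO-partitioned shards into one fp32 state dict and save it."""
    # Imported lazily so read_latest_tag works without PyTorch/DeepSpeed installed.
    import torch
    from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

    tag = read_latest_tag(checkpoint_dir)  # fail early with a clear message
    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=tag)
    torch.save(state_dict, Path(checkpoint_dir) / output_file)
```

Consolidation loads every parameter in fp32 into CPU memory, so the machine needs enough RAM to hold the full model.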
After modifying the dataset paths in NextQA-Evaluation, run:
CUDA_VISIBLE_DEVICES=0 python examples/nextqa/evaluation.py \
--base_model_path [YOUR_ARIA_PATH] \
--tokenizer_path [YOUR_ARIA_TOKENIZER_PATH] \
--save_root [YOUR_SAVE_PATH] \
--image_size 490 \
--peft_model_path [YOUR_LORA_PATH] # OPTIONAL
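NextQA questions are multiple choice with five candidate answers, so accuracy comes down to comparing the chosen option with the gold option. A minimal sketch of that scoring, assuming responses begin with an option letter such as "B." or "(C)" and gold answers are the letters A to E; evaluation.py may parse answers differently, and the function names are hypothetical.

```python
import re

# ASSUMED answer format: an option letter A-E at the start of the response,
# followed by ".", ")", ":" or the end of the string ("A man ..." is not a letter answer).
OPTION_RE = re.compile(r"\s*\(?([A-E])(?:[.):]|\s*$)")


def extract_option(response: str):
    """Return the leading option letter of a response, or None if there is none."""
    match = OPTION_RE.match(response)
    return match.group(1) if match else None


def accuracy(predictions, answers) -> float:
    """Percentage of predictions whose extracted letter equals the gold letter."""
    if len(predictions) != len(answers):
        raise ValueError("predictions and answers must have the same length")
    correct = sum(extract_option(p) == a for p, a in zip(predictions, answers))
    return 100.0 * correct / len(answers)
```

Responses that do not start with an option letter count as wrong here; a real evaluator may instead match the answer text against the candidate options.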
Accuracy on NextQA:
| Aria | LoRA SFT | Full Params SFT |
|---|---|---|
| 78.14 | 80.80 | 80.08 |
These are the loss curves of LoRA SFT and Full Params SFT:
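To plot training loss for your own run, you can read the logged values back from disk. A minimal sketch, assuming aria/train.py is built on the Hugging Face Trainer (or a subclass such as TRL's SFTTrainer), which saves a trainer_state.json with a log_history list in each checkpoint directory; the helper name load_loss_curve is hypothetical.

```python
import json
from pathlib import Path


def load_loss_curve(trainer_state_path: str):
    """Return (steps, losses) from the log_history of a trainer_state.json file."""
    state = json.loads(Path(trainer_state_path).read_text())
    points = [
        (entry["step"], entry["loss"])
        for entry in state.get("log_history", [])
        if "loss" in entry  # skip eval entries (eval_loss) and the final summary (train_loss)
    ]
    steps = [step for step, _ in points]
    losses = [loss for _, loss in points]
    return steps, losses
```

The two lists can be passed directly to matplotlib.pyplot.plot to draw a curve for each run.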