Skip to content

Latest commit

 

History

History
150 lines (105 loc) · 5.28 KB

README.md

File metadata and controls

150 lines (105 loc) · 5.28 KB

AoANet-Paddle

基于paddle框架的Attention on Attention for Image Captioning实现

一、简介

本项目基于paddle复现Attention on Attention for Image Captioning中所提出的Attention on Attention模型。该模型在传统的self-attention注意力机制的基础上,添加了gate机制以过滤和query不相关的attention信息。同时,作者还引入multi-head attention用于建模不同目标之间的关系。

注: AI Studio项目地址: https://aistudio.baidu.com/aistudio/projectdetail/3290242.

您可以使用AI Studio平台在线运行该项目!

论文:

  • [1] L. Huang, W. Wang, J. Chen, X. Wei, "Attention on Attention for Image Captioning", ICCV, 2019.

参考项目:

二、复现精度

所有指标均为模型在COCO2014的测试集评估而得

指标 BlEU-1 BlEU-2 BlEU-3 BlEU-4 METEOR ROUGE-L CIDEr-D SPICE
原论文 0.805 0.652 0.510 0.391 0.290 0.589 1.289 0.227
复现精度 0.802 0.648 0.504 0.385 0.286 0.585 1.271 0.222

三、数据集

本项目所使用的数据集为COCO2014。该数据集共包含123287张图像,每张图像对应5个标题。训练集、验证集和测试集分别为113287、5000、5000张图像及其对应的标题。本项目使用预提取的bottom-up特征,可以从这里下载得到(我们提供了脚本下载该数据集的标题以及图像特征,见download_dataset.sh)。

四、环境依赖

  • 硬件:CPU、GPU ( > 11G )

  • 软件:

    • Python 3.8
    • Java 1.8.0
    • PaddlePaddle == 2.1.0

五、快速开始

step1: clone

# clone this repo
git clone https://github.com/fuqianya/AoANet-Paddle.git --recursive
cd AoANet-Paddle

step2: 安装环境及依赖

pip install -r requirements.txt

step3: 下载数据

# 下载数据集及特征
bash ./download_dataset.sh
# 下载与计算评价指标相关的文件
bash ./coco-caption/get_google_word2vec_model.sh
bash ./coco-caption/get_stanford_models.sh

step4: 数据集预处理

python prepro.py

step5: 训练

训练过程过程分为两步(详情见论文3.3节):

  • Training with Cross Entropy (XE) Loss

    bash ./train_xe.sh
  • CIDEr-D Score Optimization

    bash ./train_rl.sh

step6: 测试

  • 测试train_xe阶段的模型

    python eval.py --model log/log_aoa/model.pdparams --infos_path log/log_aoa/infos_aoa.pkl --num_images -1 --language_eval 1 --beam_size 2 --batch_size 100 --split test
  • 测试train_rl阶段的模型

    python eval.py --model log/log_aoa_rl/model.pdparams --infos_path log/log_aoa_rl/infos_aoa.pkl --num_images -1 --language_eval 1 --beam_size 2 --batch_size 100 --split test

使用预训练模型进行预测

模型下载: 谷歌云盘

将下载的模型权重以及训练信息放到log目录下, 运行step6的指令进行测试。

六、Demo

我们提供一个Demo样例,详情见demo.ipynb

七、代码结构与详细说明

├── cider               # 计算评价指标工具
├── coco-caption        # 计算评价指标工具
├── config
│  └── config.py        # 模型的参数设置
├── data                # 预处理的数据
├── log                 # 存储训练模型及历史信息
├── model
│   └── AoAModel.py     # 定义模型结构
│   └── dataloader.py   # 加载训练数据
│   └── loss.py         # 定义损失函数
├── utils 
│   └── eval_utils.py   # 测试工具
│   └── utils.py        # 其他工具
├── download_dataset.sh # 数据集下载脚本
├── prepro.py           # 数据预处理
├── train.py            # 训练主函数
├── eval.py             # 测试主函数
├── train_xe.sh         # 训练脚本
├── train_rl.sh         # 训练脚本
└── requirement.txt     # 依赖包

模型、训练的所有参数信息都在config.py中进行了详细注释,详情见config/config.py

八、模型信息

关于模型的其他信息,可以参考下表:

信息 说明
发布者 fuqianya
时间 2021.08
框架版本 Paddle 2.1.0
应用场景 多模态
支持硬件 GPU、CPU
下载链接 预训练模型 | 训练日志