Skip to content

这是论文Unsupervised Domain Adaptation by Backpropagation的复现tf2代码,并完成了MNIST与MNIST-M数据集之间的迁移训练

Notifications You must be signed in to change notification settings

Daipuwei/DANN-MNIST-tf2

Repository files navigation

DANN-MNIST-tf2

这是论文Unsupervised Domain Adaptation by Backpropagation的复现代码,完成了MNIST与MNIST-M数据集之间的迁移训练

实验环境

  1. tensorflow=2.4.0
  2. opencv
  3. numpy
  4. pickle
  5. skimage

文档结构

  • checkpoints存放训练过程中模型权重;
  • logs存放模型训练过程中相关日志文件;
  • config存放参数配置类脚本及训练过程中参数配置文件;
  • model存放网络模型定义脚本;
  • model_data存放包括但不限于数据集、预训练模型等文件;
  • utils存放包括但不限于数据集和模型训练相关工具类和工具脚本;
  • image存放tensorboard可视化截图;
  • create_mnistm.py是根据MNIST数据集生成MNIST-M数据集的脚本;
  • train_MNIST2MNIST_M.py是利用MNIST和MNIST-M数据集进行DANN自适应模型训练的脚本;

How to train

首先下载BSDS500数据集 ,放在model_data/dataset路径下。其下载路径如下:

然后执行python create_mnistm.py生成MNIST-M数据集,根据自己需要修改create_mnistm.pyBST_PATHmnist_dirmnistm _dir ,默认路径如下:

BST_PATH = os.path.abspath('./model_data/dataset/BSR_bsds500.tgz')
mnist_dir = os.path.abspath("model_data/dataset/MNIST")
mnistm_dir = os.path.abspath("model_data/dataset/MNIST_M")

最后运行如下命令进行MNIST和MNIST-M数据集之间的自适应模型训练,根据自己的需要进行修改相关超参数,例如init_learning_ratemomentum_ratebatch_sizeepochpre_model_pathsource_dataset_pathtarget_dataset_path

python train_MNIST2MNIST_M.py

实验结果

下面主要包括了MNIST和MNIST-M数据集在自适应训练过程中学习率梯度反转层参数$\lambda$、训练集和验证集的图像分类损失域分类损失图像分类精度域分类精度模型总损失的可视化。

首先是超参数学习率梯度反转层参数$\lambda$在训练过程中的数据可视化。

超参数可视化

接着是训练数据集和验证数据集的图像分类精度域分类精度在训练过程中的数据可视化,其中蓝色代表训练集,红色代表验证集。

指标可视化

最后是训练数据集和验证数据集的图像分类损失域分类损失在训练过程中的数据可视化,其中蓝色代表训练集,红色代表验证集。

损失可视化

相关博客资料

CSDN博客链接:

  1. 【深度域自适应】一、DANN与梯度反转层(GRL)详解
  2. 【深度域自适应】二、利用DANN实现MNIST和MNIST-M数据集迁移训练

知乎专栏链接:

  1. 【深度域自适应】一、DANN与梯度反转层(GRL)详解
  2. 【深度域自适应】二、利用DANN实现MNIST和MNIST-M数据集迁移训练

About

这是论文Unsupervised Domain Adaptation by Backpropagation的复现tf2代码,并完成了MNIST与MNIST-M数据集之间的迁移训练

Resources

Stars

Watchers

Forks

Packages

No packages published

Languages