From 7fb496b4aea1b57777ea696bf6c00933e02c21db Mon Sep 17 00:00:00 2001 From: WintersMontagne10335 <22251099@zju.edu.cn> Date: Thu, 21 Dec 2023 07:46:40 +0000 Subject: [PATCH 1/4] add ASGD Chinese documents --- docs/api/paddle/optimizer/ASGD_cn.rst | 115 ++++++++++++++++++ docs/api/paddle/optimizer/Overview_cn.rst | 3 +- docs/api_guides/low_level/optimizer.rst | 12 +- docs/api_guides/low_level/optimizer_en.rst | 12 +- .../optimizer/torch.optim.ASGD.md | 48 ++++++++ 5 files changed, 187 insertions(+), 3 deletions(-) create mode 100644 docs/api/paddle/optimizer/ASGD_cn.rst create mode 100644 docs/guides/model_convert/convert_from_pytorch/api_difference/optimizer/torch.optim.ASGD.md diff --git a/docs/api/paddle/optimizer/ASGD_cn.rst b/docs/api/paddle/optimizer/ASGD_cn.rst new file mode 100644 index 00000000000..446013d4ef1 --- /dev/null +++ b/docs/api/paddle/optimizer/ASGD_cn.rst @@ -0,0 +1,115 @@ +.. _cn_api_paddle_optimizer_ASGD: + +ASGD +------------------------------- + +.. py:class:: paddle.optimizer.ASGD(learning_rate=0.001, batch_num=1, parameters=None, weight_decay=None, grad_clip=None, multi_precision=False, name=None) + +ASGD算法的优化器。有关详细信息,请参阅: + +`Minimizing Finite Sums with the Stochastic Average Gradient `_ 。 + + +.. math:: + + \begin{aligned} + &\hspace{0mm} d=0,\ y_i=0\ \textbf{for}\ i=1,2,...,n \\ + &\hspace{0mm} \textbf{for}\ \: m=0,1,...\ \textbf{do} \: \\ + &\hspace{5mm} i=m\ \%\ n \\ + &\hspace{5mm} d=d-y_i+f_i{}'(x) \\ + &\hspace{5mm} y_i=f_i{}'(x) \\ + &\hspace{5mm} x=x-learning\_rate(\frac{d}{\mathrm{min}(m+1,\ n)}+\lambda x) \\ + &\hspace{0mm} \textbf{end for} \\ + \end{aligned} + + +参数 +:::::::::::: + + - **learning_rate** (float|_LRScheduleri,可选) - 学习率,用于参数更新的计算。可以是一个浮点型值或者一个_LRScheduler 类。默认值为 0.001。 + - **batch_num** (int,可选) - 完成一个 epoch 所需迭代的次数。默认值为 1。 + - **parameters** (list,可选) - 指定优化器需要优化的参数。在动态图模式下必须提供该参数;在静态图模式下默认值为 None,这时所有的参数都将被优化。 + - **weight_decay** (float|Tensor,可选) - 权重衰减系数,是一个 float 类型或者 shape 为[1],数据类型为 float32 的 Tensor 类型。默认值为 None。 + - **grad_clip** (GradientClipBase,可选) – 梯度裁剪的策略,支持三种裁剪策略::ref:`paddle.nn.ClipGradByGlobalNorm ` 、 :ref:`paddle.nn.ClipGradByNorm ` 、 :ref:`paddle.nn.ClipGradByValue ` 。 + 默认值为 None,此时将不进行梯度裁剪。 + - **multi_precision** (bool,可选) – 在基于 GPU 设备的混合精度训练场景中,该参数主要用于保证梯度更新的数值稳定性。设置为 True 时,优化器会针对 FP16 类型参数保存一份与其值相等的 FP32 类型参数备份。梯度更新时,首先将梯度类型提升到 FP32,然后将其更新到 FP32 类型参数备份中。最后,更新后的 FP32 类型值会先转换为 FP16 类型,再赋值给实际参与计算的 FP16 类型参数。默认为 False。 + - **name** (str,可选) - 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 None。 + + +代码示例 +:::::::::::: + +COPY-FROM: paddle.optimizer.ASGD + + +方法 +:::::::::::: +step() +''''''''' + +.. note:: + + 该 API 只在 `Dygraph <../../user_guides/howto/dygraph/DyGraph.html>`_ 模式下生效。 + +执行一次优化器并进行参数更新。 + +**返回** + +无。 + +**代码示例** + +COPY-FROM: paddle.optimizer.ASGD.step + +minimize(loss, startup_program=None, parameters=None, no_grad_set=None) +''''''''' + +为网络添加反向计算过程,并根据反向计算所得的梯度,更新 parameters 中的 Parameters,最小化网络损失值 loss。 + +**参数** + + - **loss** (Tensor) - 需要最小化的损失值变量 + - **startup_program** (Program,可选) - 用于初始化 parameters 中参数的 :ref:`cn_api_paddle_static_Program`,默认值为 None,此时将使用 :ref:`cn_api_paddle_static_default_startup_program` 。 + - **parameters** (list,可选) - 待更新的 Parameter 或者 Parameter.name 组成的列表,默认值为 None,此时将更新所有的 Parameter。 + - **no_grad_set** (set,可选) - 不需要更新的 Parameter 或者 Parameter.name 组成的集合,默认值为 None。 + +**返回** + + tuple(optimize_ops, params_grads),其中 optimize_ops 为参数优化 OP 列表;param_grads 为由(param, param_grad)组成的列表,其中 param 和 param_grad 分别为参数和参数的梯度。在静态图模式下,该返回值可以加入到 ``Executor.run()`` 接口的 ``fetch_list`` 参数中,若加入,则会重写 ``use_prune`` 参数为 True,并根据 ``feed`` 和 ``fetch_list`` 进行剪枝,详见 ``Executor`` 的文档。 + + +**代码示例** + +COPY-FROM: paddle.optimizer.ASGD.minimize + +clear_grad() +''''''''' + +.. note:: + + 该 API 只在 `Dygraph <../../user_guides/howto/dygraph/DyGraph.html>`_ 模式下生效。 + + +清除需要优化的参数的梯度。 + +**代码示例** + +COPY-FROM: paddle.optimizer.ASGD.clear_grad + +get_lr() +''''''''' + +.. note:: + + 该 API 只在 `Dygraph <../../user_guides/howto/dygraph/DyGraph.html>`_ 模式下生效。 + +获取当前步骤的学习率。当不使用_LRScheduler 时,每次调用的返回值都相同,否则返回当前步骤的学习率。 + +**返回** + +float,当前步骤的学习率。 + + +**代码示例** + +COPY-FROM: paddle.optimizer.ASGD.get_lr diff --git a/docs/api/paddle/optimizer/Overview_cn.rst b/docs/api/paddle/optimizer/Overview_cn.rst index 75965328ac0..ef663a19386 100644 --- a/docs/api/paddle/optimizer/Overview_cn.rst +++ b/docs/api/paddle/optimizer/Overview_cn.rst @@ -24,13 +24,14 @@ paddle.optimizer 目录下包含飞桨框架支持的优化器算法相关的 AP " :ref:`Adam ` ", "Adam 优化器" " :ref:`Adamax ` ", "Adamax 优化器" " :ref:`AdamW ` ", "AdamW 优化器" + " :ref:`ASGD ` ", "ASGD 优化器" " :ref:`Lamb ` ", "Lamb 优化器" " :ref:`LBFGS ` ", "LBFGS 优化器" " :ref:`Momentum ` ", "Momentum 优化器" " :ref:`Optimizer ` ", "飞桨框架优化器基类" " :ref:`RMSProp ` ", "RMSProp 优化器" - " :ref:`SGD ` ", "SGD 优化器" " :ref:`Rprop ` ", "Rprop 优化器" + " :ref:`SGD ` ", "SGD 优化器" .. _about_lr: diff --git a/docs/api_guides/low_level/optimizer.rst b/docs/api_guides/low_level/optimizer.rst index 86d41a59aa8..3615605f288 100644 --- a/docs/api_guides/low_level/optimizer.rst +++ b/docs/api_guides/low_level/optimizer.rst @@ -97,4 +97,14 @@ API Reference 请参考 :ref:`cn_api_fluid_optimizer_ModelAverage` :code:`Rprop` 优化器,该方法考虑到不同权值参数的梯度的数量级可能相差很大,因此很难找到一个全局的学习步长。因此创新性地提出靠参数梯度的符号,动态的调节学习步长以加速优化过程的方法。 -API Reference 请参考 :ref:`cn_api_fluid_optimizer_Rprop` \ No newline at end of file +API Reference 请参考 :ref:`cn_api_fluid_optimizer_Rprop` + + + + +11.ASGD/ASGDOptimizer +----------------- + +:code:`ASGD` 优化器,是 `SGD` 以空间换时间的策略版本,是一种轨迹平均的随机优化方法。 `ASGD` 在 `SGD` 的基础上,增加了历史参数的平均值度量,让下降方向噪音的方差呈递减趋势下降,从而使得算法最终会以线性速度收敛于最优值。 + +API Reference 请参考 :ref:`cn_api_fluid_optimizer_ASGD` \ No newline at end of file diff --git a/docs/api_guides/low_level/optimizer_en.rst b/docs/api_guides/low_level/optimizer_en.rst index 796d1895a69..17c41318d16 100755 --- a/docs/api_guides/low_level/optimizer_en.rst +++ b/docs/api_guides/low_level/optimizer_en.rst @@ -96,4 +96,14 @@ API Reference: :ref:`api_fluid_optimizer_ModelAverage` :code:`Rprop` Optimizer, this method considers that the magnitude of gradients for different weight parameters may vary greatly, making it difficult to find a global learning step size. Therefore, an innovative method is proposed to accelerate the optimization process by dynamically adjusting the learning step size through the use of parameter gradient symbols. -API Reference: :ref:`api_fluid_optimizer_Rprop` \ No newline at end of file +API Reference: :ref:`api_fluid_optimizer_Rprop` + + + + +11.ASGD/ASGDOptimizer +----------------- + +:code:`ASGD` Optimizer, it is a strategy version of SGD that trades space for time, and is a stochastic optimization method with trajectory averaging. On the basis of SGD, ASGD adds a measure of the average value of historical parameters, making the variance of noise in the descending direction decrease in a decreasing trend, so that the algorithm will eventually converge to the optimal value at a linear speed. + +API Reference: :ref:`api_fluid_optimizer_ASGD` \ No newline at end of file diff --git a/docs/guides/model_convert/convert_from_pytorch/api_difference/optimizer/torch.optim.ASGD.md b/docs/guides/model_convert/convert_from_pytorch/api_difference/optimizer/torch.optim.ASGD.md new file mode 100644 index 00000000000..b078d90c129 --- /dev/null +++ b/docs/guides/model_convert/convert_from_pytorch/api_difference/optimizer/torch.optim.ASGD.md @@ -0,0 +1,48 @@ +## [ torch 参数更多 ]torch.optim.ASGD + +### [torch.optim.ASGD](https://pytorch.org/docs/stable/generated/torch.optim.ASGD.html) + +```python +torch.optim.ASGD(params, + lr=0.01, + lambd=0.0001, + alpha=0.75, + t0=1000000.0, + weight_decay=0, + foreach=None, + maximize=False, + differentiable=False) +``` + +### [paddle.optimizer.ASGD](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/optimizer/ASGD_cn.html#cn-api-paddle-optimizer-asgd) + +```python +paddle.optimizer.ASGD(learning_rate=0.001, + batch_num=1, + parameters=None, + weight_decay=None, + grad_clip=None, + multi_precision=False, + name=None) +``` + +注:Pytorch 的 ASGD 是有问题的。 +Pytorch 相比 Paddle 支持更多其他参数,具体如下: + +### 参数映射 + +| PyTorch | PaddlePaddle | 备注 | +| ------------- | ------------------- | ----------------------------------------------------------------------------------------------------------------------- | +| params | parameters | 表示指定优化器需要优化的参数,仅参数名不一致 | +| lr | learning_rate | 学习率,用于参数更新的计算。参数默认值不一致, Pytorch 默认为 `0.0001`, Paddle 默认为 `0.001`,Paddle 需保持与 Pytorch 一致 | +| lambd | - | 衰变项,与 weight_decay 功能重叠,暂无转写方式 | +| alpha | - | eta 更新的 power,暂无转写方式 | +| t0 | - | 开始求平均值的点,暂无转写方式 | +| weight_decay | weight_decay | 权重衰减。参数默认值不一致, Pytorch 默认为 `0`, Paddle 默认为 `None`,Paddle 需保持与 Pytorch 一致 | +| foreach | - | 是否使用优化器的 foreach 实现。Paddle 无此参数,一般对网络训练结果影响不大,可直接删除 | +| maximize | - | 根据目标最大化参数,而不是最小化。Paddle 无此参数,暂无转写方式 | +| differentiable| - | 是否应通过训练中的优化器步骤进行自动微分。Paddle 无此参数,一般对网络训练结果影响不大,可直接删除 | +| - | batch_num | 完成一个 epoch 所需迭代的次数。 PyTorch 无此参数,Paddle 需要根据样本数据设置 | +| - | grad_clip | 梯度裁剪的策略。 PyTorch 无此参数,Paddle 保持默认即可 | +| - | multi_precision | 在基于 GPU 设备的混合精度训练场景中,该参数主要用于保证梯度更新的数值稳定性。 PyTorch 无此参数,Paddle 保持默认即可 | +| - | name | 一般情况下无需设置。 PyTorch 无此参数,Paddle 保持默认即可 | \ No newline at end of file From b484003303e42c2171c724bef6795221ba28ad5d Mon Sep 17 00:00:00 2001 From: Winters Montagne <118546135+WintersMontagne10335@users.noreply.github.com> Date: Wed, 24 Jan 2024 19:38:37 +0800 Subject: [PATCH 2/4] Update torch.optim.ASGD.md --- .../api_difference/optimizer/torch.optim.ASGD.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/guides/model_convert/convert_from_pytorch/api_difference/optimizer/torch.optim.ASGD.md b/docs/guides/model_convert/convert_from_pytorch/api_difference/optimizer/torch.optim.ASGD.md index b078d90c129..bc3000a866f 100644 --- a/docs/guides/model_convert/convert_from_pytorch/api_difference/optimizer/torch.optim.ASGD.md +++ b/docs/guides/model_convert/convert_from_pytorch/api_difference/optimizer/torch.optim.ASGD.md @@ -42,7 +42,7 @@ Pytorch 相比 Paddle 支持更多其他参数,具体如下: | foreach | - | 是否使用优化器的 foreach 实现。Paddle 无此参数,一般对网络训练结果影响不大,可直接删除 | | maximize | - | 根据目标最大化参数,而不是最小化。Paddle 无此参数,暂无转写方式 | | differentiable| - | 是否应通过训练中的优化器步骤进行自动微分。Paddle 无此参数,一般对网络训练结果影响不大,可直接删除 | -| - | batch_num | 完成一个 epoch 所需迭代的次数。 PyTorch 无此参数,Paddle 需要根据样本数据设置 | +| - | batch_num | 完成一个 epoch 所需迭代的次数。 PyTorch 无此参数。假设样本总数为 all_size,Paddle 需将 batch_num 设置为 all_size / batch_size | | - | grad_clip | 梯度裁剪的策略。 PyTorch 无此参数,Paddle 保持默认即可 | | - | multi_precision | 在基于 GPU 设备的混合精度训练场景中,该参数主要用于保证梯度更新的数值稳定性。 PyTorch 无此参数,Paddle 保持默认即可 | -| - | name | 一般情况下无需设置。 PyTorch 无此参数,Paddle 保持默认即可 | \ No newline at end of file +| - | name | 一般情况下无需设置。 PyTorch 无此参数,Paddle 保持默认即可 | From 2a3978ac59435227e60a755a381b64d181a4beba Mon Sep 17 00:00:00 2001 From: Winters Montagne <118546135+WintersMontagne10335@users.noreply.github.com> Date: Thu, 25 Jan 2024 20:59:01 +0800 Subject: [PATCH 3/4] Update torch.optim.ASGD.md --- .../optimizer/torch.optim.ASGD.md | 47 ++++++++++++++++--- 1 file changed, 41 insertions(+), 6 deletions(-) diff --git a/docs/guides/model_convert/convert_from_pytorch/api_difference/optimizer/torch.optim.ASGD.md b/docs/guides/model_convert/convert_from_pytorch/api_difference/optimizer/torch.optim.ASGD.md index bc3000a866f..07be8678949 100644 --- a/docs/guides/model_convert/convert_from_pytorch/api_difference/optimizer/torch.optim.ASGD.md +++ b/docs/guides/model_convert/convert_from_pytorch/api_difference/optimizer/torch.optim.ASGD.md @@ -1,4 +1,4 @@ -## [ torch 参数更多 ]torch.optim.ASGD +## [ 功能缺失 ]torch.optim.ASGD ### [torch.optim.ASGD](https://pytorch.org/docs/stable/generated/torch.optim.ASGD.html) @@ -35,14 +35,49 @@ Pytorch 相比 Paddle 支持更多其他参数,具体如下: | ------------- | ------------------- | ----------------------------------------------------------------------------------------------------------------------- | | params | parameters | 表示指定优化器需要优化的参数,仅参数名不一致 | | lr | learning_rate | 学习率,用于参数更新的计算。参数默认值不一致, Pytorch 默认为 `0.0001`, Paddle 默认为 `0.001`,Paddle 需保持与 Pytorch 一致 | -| lambd | - | 衰变项,与 weight_decay 功能重叠,暂无转写方式 | -| alpha | - | eta 更新的 power,暂无转写方式 | -| t0 | - | 开始求平均值的点,暂无转写方式 | -| weight_decay | weight_decay | 权重衰减。参数默认值不一致, Pytorch 默认为 `0`, Paddle 默认为 `None`,Paddle 需保持与 Pytorch 一致 | +| lambd | - | 衰变项,与 weight_decay 功能重叠,可直接删除 | +| alpha | - | eta 更新的 power,可直接删除 | +| t0 | - | 开始求平均值的点,可直接删除 | +| weight_decay | weight_decay | 权重衰减。参数默认值不一致, Pytorch 默认为 `0`, Paddle 默认为 `None`,Paddle 需保持与 Pytorch 一致 | | foreach | - | 是否使用优化器的 foreach 实现。Paddle 无此参数,一般对网络训练结果影响不大,可直接删除 | -| maximize | - | 根据目标最大化参数,而不是最小化。Paddle 无此参数,暂无转写方式 | +| maximize | - | 根据目标最大化参数,而不是最小化。Paddle 无此参数,可直接删除 | | differentiable| - | 是否应通过训练中的优化器步骤进行自动微分。Paddle 无此参数,一般对网络训练结果影响不大,可直接删除 | | - | batch_num | 完成一个 epoch 所需迭代的次数。 PyTorch 无此参数。假设样本总数为 all_size,Paddle 需将 batch_num 设置为 all_size / batch_size | | - | grad_clip | 梯度裁剪的策略。 PyTorch 无此参数,Paddle 保持默认即可 | | - | multi_precision | 在基于 GPU 设备的混合精度训练场景中,该参数主要用于保证梯度更新的数值稳定性。 PyTorch 无此参数,Paddle 保持默认即可 | | - | name | 一般情况下无需设置。 PyTorch 无此参数,Paddle 保持默认即可 | + +### 相关问题 + +torch 当前版本的 ASGD 实现并不完善。转换过来的 paddle ASGD 会与 torch 的不一致(不影响收敛),但是可以正常使用。如果强需求保证转换前后一致,可以自行尝试其他优化器。 + +如果后续 torch 有代码更新,可以联系 @WintersMontagne10335 作 API 调整与对接。 + +#### torch 现存问题 + +在 `_single_tensor_asgd` 中,对 `axs, ax` 进行了更新,但是它们却并没有参与到 `params` 中。 `axs, ax` 完全没有作用。 + +调研到的比较可信的原因是,目前 `ASGD` 的功能并不完善, `axs, ax` 是预留给以后的版本的。 + +另外,weight_decay 是冗余的。 + +当前版本 `ASGD` 的功能,类似于 `SGD` 。 + +详情可见: +- https://discuss.pytorch.org/t/asgd-optimizer-has-a-bug/95060 +- https://discuss.pytorch.org/t/averaged-sgd-implementation/26960 +- https://github.com/pytorch/pytorch/issues/74884 + +#### paddle 实现思路 + +主要参照 [`ASGD` 论文: Minimizing Finite Sums with the Stochastic Average Gradient](https://inria.hal.science/hal-00860051v2) + +核心步骤为: + + 1. 初始化 d, y + 2. 随机采样 + 3. 用本次计算得到的第 i 个样本的梯度信息,替换上一次的梯度信息 + 4. 更新参数 + +伪代码和详细实现步骤可见: +- https://github.com/PaddlePaddle/community/blob/b76313c3b8f8b6a2f808d90fa95dcf265dbef67d/rfcs/APIs/20231111_api_design_for_ASGD.md From 254640ab6ffd48bacef8323336c0144d89477cf9 Mon Sep 17 00:00:00 2001 From: Winters Montagne <118546135+WintersMontagne10335@users.noreply.github.com> Date: Fri, 26 Jan 2024 16:15:25 +0800 Subject: [PATCH 4/4] Update torch.optim.ASGD.md --- .../api_difference/optimizer/torch.optim.ASGD.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/guides/model_convert/convert_from_pytorch/api_difference/optimizer/torch.optim.ASGD.md b/docs/guides/model_convert/convert_from_pytorch/api_difference/optimizer/torch.optim.ASGD.md index 07be8678949..0024e0007f0 100644 --- a/docs/guides/model_convert/convert_from_pytorch/api_difference/optimizer/torch.optim.ASGD.md +++ b/docs/guides/model_convert/convert_from_pytorch/api_difference/optimizer/torch.optim.ASGD.md @@ -1,4 +1,4 @@ -## [ 功能缺失 ]torch.optim.ASGD +## [ torch 参数更多 ]torch.optim.ASGD ### [torch.optim.ASGD](https://pytorch.org/docs/stable/generated/torch.optim.ASGD.html) @@ -40,7 +40,7 @@ Pytorch 相比 Paddle 支持更多其他参数,具体如下: | t0 | - | 开始求平均值的点,可直接删除 | | weight_decay | weight_decay | 权重衰减。参数默认值不一致, Pytorch 默认为 `0`, Paddle 默认为 `None`,Paddle 需保持与 Pytorch 一致 | | foreach | - | 是否使用优化器的 foreach 实现。Paddle 无此参数,一般对网络训练结果影响不大,可直接删除 | -| maximize | - | 根据目标最大化参数,而不是最小化。Paddle 无此参数,可直接删除 | +| maximize | - | 根据目标最大化参数,而不是最小化。Paddle 无此参数,暂无转写方式 | | differentiable| - | 是否应通过训练中的优化器步骤进行自动微分。Paddle 无此参数,一般对网络训练结果影响不大,可直接删除 | | - | batch_num | 完成一个 epoch 所需迭代的次数。 PyTorch 无此参数。假设样本总数为 all_size,Paddle 需将 batch_num 设置为 all_size / batch_size | | - | grad_clip | 梯度裁剪的策略。 PyTorch 无此参数,Paddle 保持默认即可 |