-
Notifications
You must be signed in to change notification settings - Fork 753
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 5th No.40】为 Paddle 新增 ASGD API 中文文档 #6412
Merged
luotao1
merged 4 commits into
PaddlePaddle:develop
from
WintersMontagne10335:winters018
Jan 26, 2024
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
.. _cn_api_paddle_optimizer_ASGD: | ||
|
||
ASGD | ||
------------------------------- | ||
|
||
.. py:class:: paddle.optimizer.ASGD(learning_rate=0.001, batch_num=1, parameters=None, weight_decay=None, grad_clip=None, multi_precision=False, name=None) | ||
|
||
ASGD算法的优化器。有关详细信息,请参阅: | ||
|
||
`Minimizing Finite Sums with the Stochastic Average Gradient <https://hal.science/hal-00860051v2>`_ 。 | ||
|
||
|
||
.. math:: | ||
|
||
\begin{aligned} | ||
&\hspace{0mm} d=0,\ y_i=0\ \textbf{for}\ i=1,2,...,n \\ | ||
&\hspace{0mm} \textbf{for}\ \: m=0,1,...\ \textbf{do} \: \\ | ||
&\hspace{5mm} i=m\ \%\ n \\ | ||
&\hspace{5mm} d=d-y_i+f_i{}'(x) \\ | ||
&\hspace{5mm} y_i=f_i{}'(x) \\ | ||
&\hspace{5mm} x=x-learning\_rate(\frac{d}{\mathrm{min}(m+1,\ n)}+\lambda x) \\ | ||
&\hspace{0mm} \textbf{end for} \\ | ||
\end{aligned} | ||
|
||
|
||
参数 | ||
:::::::::::: | ||
|
||
- **learning_rate** (float|_LRScheduleri,可选) - 学习率,用于参数更新的计算。可以是一个浮点型值或者一个_LRScheduler 类。默认值为 0.001。 | ||
- **batch_num** (int,可选) - 完成一个 epoch 所需迭代的次数。默认值为 1。 | ||
- **parameters** (list,可选) - 指定优化器需要优化的参数。在动态图模式下必须提供该参数;在静态图模式下默认值为 None,这时所有的参数都将被优化。 | ||
- **weight_decay** (float|Tensor,可选) - 权重衰减系数,是一个 float 类型或者 shape 为[1],数据类型为 float32 的 Tensor 类型。默认值为 None。 | ||
- **grad_clip** (GradientClipBase,可选) – 梯度裁剪的策略,支持三种裁剪策略::ref:`paddle.nn.ClipGradByGlobalNorm <cn_api_paddle_nn_ClipGradByGlobalNorm>` 、 :ref:`paddle.nn.ClipGradByNorm <cn_api_paddle_nn_ClipGradByNorm>` 、 :ref:`paddle.nn.ClipGradByValue <cn_api_paddle_nn_ClipGradByValue>` 。 | ||
默认值为 None,此时将不进行梯度裁剪。 | ||
- **multi_precision** (bool,可选) – 在基于 GPU 设备的混合精度训练场景中,该参数主要用于保证梯度更新的数值稳定性。设置为 True 时,优化器会针对 FP16 类型参数保存一份与其值相等的 FP32 类型参数备份。梯度更新时,首先将梯度类型提升到 FP32,然后将其更新到 FP32 类型参数备份中。最后,更新后的 FP32 类型值会先转换为 FP16 类型,再赋值给实际参与计算的 FP16 类型参数。默认为 False。 | ||
- **name** (str,可选) - 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 None。 | ||
|
||
|
||
代码示例 | ||
:::::::::::: | ||
|
||
COPY-FROM: paddle.optimizer.ASGD | ||
|
||
|
||
方法 | ||
:::::::::::: | ||
step() | ||
''''''''' | ||
|
||
.. note:: | ||
|
||
该 API 只在 `Dygraph <../../user_guides/howto/dygraph/DyGraph.html>`_ 模式下生效。 | ||
|
||
执行一次优化器并进行参数更新。 | ||
|
||
**返回** | ||
|
||
无。 | ||
|
||
**代码示例** | ||
|
||
COPY-FROM: paddle.optimizer.ASGD.step | ||
|
||
minimize(loss, startup_program=None, parameters=None, no_grad_set=None) | ||
''''''''' | ||
|
||
为网络添加反向计算过程,并根据反向计算所得的梯度,更新 parameters 中的 Parameters,最小化网络损失值 loss。 | ||
|
||
**参数** | ||
|
||
- **loss** (Tensor) - 需要最小化的损失值变量 | ||
- **startup_program** (Program,可选) - 用于初始化 parameters 中参数的 :ref:`cn_api_paddle_static_Program`,默认值为 None,此时将使用 :ref:`cn_api_paddle_static_default_startup_program` 。 | ||
- **parameters** (list,可选) - 待更新的 Parameter 或者 Parameter.name 组成的列表,默认值为 None,此时将更新所有的 Parameter。 | ||
- **no_grad_set** (set,可选) - 不需要更新的 Parameter 或者 Parameter.name 组成的集合,默认值为 None。 | ||
|
||
**返回** | ||
|
||
tuple(optimize_ops, params_grads),其中 optimize_ops 为参数优化 OP 列表;param_grads 为由(param, param_grad)组成的列表,其中 param 和 param_grad 分别为参数和参数的梯度。在静态图模式下,该返回值可以加入到 ``Executor.run()`` 接口的 ``fetch_list`` 参数中,若加入,则会重写 ``use_prune`` 参数为 True,并根据 ``feed`` 和 ``fetch_list`` 进行剪枝,详见 ``Executor`` 的文档。 | ||
|
||
|
||
**代码示例** | ||
|
||
COPY-FROM: paddle.optimizer.ASGD.minimize | ||
|
||
clear_grad() | ||
''''''''' | ||
|
||
.. note:: | ||
|
||
该 API 只在 `Dygraph <../../user_guides/howto/dygraph/DyGraph.html>`_ 模式下生效。 | ||
|
||
|
||
清除需要优化的参数的梯度。 | ||
|
||
**代码示例** | ||
|
||
COPY-FROM: paddle.optimizer.ASGD.clear_grad | ||
|
||
get_lr() | ||
''''''''' | ||
|
||
.. note:: | ||
|
||
该 API 只在 `Dygraph <../../user_guides/howto/dygraph/DyGraph.html>`_ 模式下生效。 | ||
|
||
获取当前步骤的学习率。当不使用_LRScheduler 时,每次调用的返回值都相同,否则返回当前步骤的学习率。 | ||
|
||
**返回** | ||
|
||
float,当前步骤的学习率。 | ||
|
||
|
||
**代码示例** | ||
|
||
COPY-FROM: paddle.optimizer.ASGD.get_lr |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
83 changes: 83 additions & 0 deletions
83
...model_convert/convert_from_pytorch/api_difference/optimizer/torch.optim.ASGD.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
## [ torch 参数更多 ]torch.optim.ASGD | ||
|
||
### [torch.optim.ASGD](https://pytorch.org/docs/stable/generated/torch.optim.ASGD.html) | ||
|
||
```python | ||
torch.optim.ASGD(params, | ||
lr=0.01, | ||
lambd=0.0001, | ||
alpha=0.75, | ||
t0=1000000.0, | ||
weight_decay=0, | ||
foreach=None, | ||
maximize=False, | ||
differentiable=False) | ||
``` | ||
|
||
### [paddle.optimizer.ASGD](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/optimizer/ASGD_cn.html#cn-api-paddle-optimizer-asgd) | ||
|
||
```python | ||
paddle.optimizer.ASGD(learning_rate=0.001, | ||
batch_num=1, | ||
parameters=None, | ||
weight_decay=None, | ||
grad_clip=None, | ||
multi_precision=False, | ||
name=None) | ||
``` | ||
|
||
注:Pytorch 的 ASGD 是有问题的。 | ||
Pytorch 相比 Paddle 支持更多其他参数,具体如下: | ||
|
||
### 参数映射 | ||
|
||
| PyTorch | PaddlePaddle | 备注 | | ||
| ------------- | ------------------- | ----------------------------------------------------------------------------------------------------------------------- | | ||
| params | parameters | 表示指定优化器需要优化的参数,仅参数名不一致 | | ||
| lr | learning_rate | 学习率,用于参数更新的计算。参数默认值不一致, Pytorch 默认为 `0.0001`, Paddle 默认为 `0.001`,Paddle 需保持与 Pytorch 一致 | | ||
| lambd | - | 衰变项,与 weight_decay 功能重叠,可直接删除 | | ||
| alpha | - | eta 更新的 power,可直接删除 | | ||
| t0 | - | 开始求平均值的点,可直接删除 | | ||
| weight_decay | weight_decay | 权重衰减。参数默认值不一致, Pytorch 默认为 `0`, Paddle 默认为 `None`,Paddle 需保持与 Pytorch 一致 | | ||
| foreach | - | 是否使用优化器的 foreach 实现。Paddle 无此参数,一般对网络训练结果影响不大,可直接删除 | | ||
| maximize | - | 根据目标最大化参数,而不是最小化。Paddle 无此参数,暂无转写方式 | | ||
| differentiable| - | 是否应通过训练中的优化器步骤进行自动微分。Paddle 无此参数,一般对网络训练结果影响不大,可直接删除 | | ||
| - | batch_num | 完成一个 epoch 所需迭代的次数。 PyTorch 无此参数。假设样本总数为 all_size,Paddle 需将 batch_num 设置为 all_size / batch_size | | ||
| - | grad_clip | 梯度裁剪的策略。 PyTorch 无此参数,Paddle 保持默认即可 | | ||
| - | multi_precision | 在基于 GPU 设备的混合精度训练场景中,该参数主要用于保证梯度更新的数值稳定性。 PyTorch 无此参数,Paddle 保持默认即可 | | ||
| - | name | 一般情况下无需设置。 PyTorch 无此参数,Paddle 保持默认即可 | | ||
|
||
### 相关问题 | ||
|
||
torch 当前版本的 ASGD 实现并不完善。转换过来的 paddle ASGD 会与 torch 的不一致(不影响收敛),但是可以正常使用。如果强需求保证转换前后一致,可以自行尝试其他优化器。 | ||
|
||
如果后续 torch 有代码更新,可以联系 @WintersMontagne10335 作 API 调整与对接。 | ||
|
||
#### torch 现存问题 | ||
|
||
在 `_single_tensor_asgd` 中,对 `axs, ax` 进行了更新,但是它们却并没有参与到 `params` 中。 `axs, ax` 完全没有作用。 | ||
|
||
调研到的比较可信的原因是,目前 `ASGD` 的功能并不完善, `axs, ax` 是预留给以后的版本的。 | ||
|
||
另外,weight_decay 是冗余的。 | ||
|
||
当前版本 `ASGD` 的功能,类似于 `SGD` 。 | ||
|
||
详情可见: | ||
- https://discuss.pytorch.org/t/asgd-optimizer-has-a-bug/95060 | ||
- https://discuss.pytorch.org/t/averaged-sgd-implementation/26960 | ||
- https://github.com/pytorch/pytorch/issues/74884 | ||
|
||
#### paddle 实现思路 | ||
|
||
主要参照 [`ASGD` 论文: Minimizing Finite Sums with the Stochastic Average Gradient](https://inria.hal.science/hal-00860051v2) | ||
|
||
核心步骤为: | ||
|
||
1. 初始化 d, y | ||
2. 随机采样 | ||
3. 用本次计算得到的第 i 个样本的梯度信息,替换上一次的梯度信息 | ||
4. 更新参数 | ||
|
||
伪代码和详细实现步骤可见: | ||
- https://github.com/PaddlePaddle/community/blob/b76313c3b8f8b6a2f808d90fa95dcf265dbef67d/rfcs/APIs/20231111_api_design_for_ASGD.md |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
那目前这两个API实现差异还是比较大的,对
torch.optim.ASGD
直接转成paddle.optimizer.ASGD
结果肯定是对不上的,所以在映射文档里写明白原因吧:torch的问题、为何实现不一致、如果使用paddle的ASGD结果会对不上但不一定影响最终收敛,或者自行尝试其他优化器,让用户知道这里有坑不容易对齐。There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
然后就按 功能缺失 来处理吧,后面torch如果更新了再调整
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
修改了一下,您看可以嘛