
AdamW implementation does not truly decouple learning rate and weight decay #1849

Open
leenachennuru opened this issue Oct 9, 2024 · 2 comments
Labels: bug (Something isn't working)

leenachennuru commented Oct 9, 2024

Describe the bug

The AdamW implementation (see here) does not truly decouple the weight decay and learning rate parameters, in contrast to the AdamW paper. This coupling complicates hyperparameter (HP) tuning: changing the learning rate also changes the effective weight decay (WD) used to train the model.

The implementation computes the updates as

$w_{t} = (1 - \eta_{\text{effective}} \lambda)\, w_{t-1} - \eta_{\text{effective}} \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$

where $\eta_{\text{effective}} = \eta_t \eta_{\text{max}}$, with $\eta_t$ denoting the scheduler multiplier and $\eta_{\text{max}}$ the max/base LR.

This clearly couples LR and WD and is not in line with the paper, which proposes to compute the updates as

$w_{t} = (1 - \eta_t \lambda)\, w_{t-1} - \eta_t \eta_{\text{max}} \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$

For easier and more intuitive tuning, it would be useful to enable the completely decoupled version of AdamW via a simple fix: replace the decay coefficient with $\tilde{\lambda} = (\eta_{\text{effective}} / \eta_{\text{max}})\, \lambda$ and compute the updates as $w_{t} = (1 - \tilde{\lambda})\, w_{t-1} - \eta_{\text{effective}} \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$. A sketch of what this could look like is shown below.
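A minimal sketch of the two behaviors (illustrative only, not the library's actual implementation), assuming the bias-corrected moments $\hat{m}_t$ and $\hat{v}_t$ are already computed and that `sched` is the scheduler multiplier $\eta_t$:

```python
import torch

def adamw_decay_and_update(p, m_hat, v_hat, lr_max, sched, weight_decay,
                           eps=1e-8, fully_decoupled=False):
    """Apply one AdamW weight-decay + gradient update to parameter tensor `p`.

    m_hat, v_hat : bias-corrected first/second moments (same shape as p)
    lr_max       : base/max learning rate (eta_max)
    sched        : scheduler multiplier in [0, 1] (eta_t)
    """
    lr_eff = sched * lr_max  # eta_effective = eta_t * eta_max

    if fully_decoupled:
        # Proposed fix: decay scales only with the scheduler, so retuning
        # lr_max does not change the effective weight decay.
        decay = sched * weight_decay      # (eta_effective / eta_max) * lambda
    else:
        # Current behavior: decay scales with the effective LR, which couples
        # the effective weight decay to the learning rate.
        decay = lr_eff * weight_decay     # eta_effective * lambda

    p.mul_(1.0 - decay)
    p.add_(m_hat / (v_hat.sqrt() + eps), alpha=-lr_eff)
```

With `fully_decoupled=True`, sweeping `lr_max` changes only the gradient step size while the per-step decay factor $\eta_t \lambda$ stays fixed, which is the decoupling described in the paper.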

Note: This bug also exists in the AdamW implementations in PyTorch and Optax and has already been highlighted a few times across different papers, libraries, and blogs. More links below for reference.

  1. Mosaic ML Library
  2. Optimi
  3. Paper: How to set AdamW's weight decay as you scale model and dataset size
  4. Fabian Schaipp's blog
timmoon10 (Contributor) commented Oct 9, 2024

For better or for worse, I think "AdamW" now refers to the LR-coupled version. In addition to PyTorch and JAX, I see this formulation in Keras (and therefore TensorFlow), PaddlePaddle, and MXNet. If we implement an LR-decoupled variant, we should give it a new name or make it an opt-in option so we don't confuse users.

There has been a lot of discussion of this in other frameworks. It seems PyTorch deliberately made the decision to use the LR-coupled variant, and that choice has percolated through the entire ecosystem.

leenachennuru (Author) commented
Allowing the user to invoke the fully decoupled version via either option (opt-in or a new name) would be helpful. A couple more references on the potential utility of independent WD are listed below.

  1. Small-scale proxies for large-scale Transformer training instabilities
  2. A Large-Scale Exploration of µ-Transfer
  3. u-μP: The Unit-Scaled Maximal Update Parametrization
