Feature/44 make flash attention configurable #47
base: develop
Conversation
* fix: change pre-commit autoupdate schedule to monthly
* fix: change the merge strategy for Changelog to Union
* fix: add .envrc to .gitignore
* ci: ignore pre-commit-config and readthedocs for changelog updates
* ci: fix to correct hpc workflow call
* fix: update precommit config
* chore: update pre-commits
* feat: add codeowners file
* chore: update dependencies
* ci: add hpc-config
* docs: changelog
* fix: respond to review comments

Co-authored-by: Jesper Dramsch <[email protected]>
* feat: add configurability to dropout in MultiHeadSelfAttention
  Co-authored-by: Rilwan (Akanni) Adewoyin <[email protected]>
* test: adjust to dropout_p
* doc: update changelog
* Feature/integrate reusable workflows (#16)
  * ci: add public pr label
  * ci: add readthedocs update check
  * ci: add downstream ci
  * ci: add ci-config
  * chore(deps): remove unused dependency
  * docs: update changelog
  * ci: switch to main
* chore: changelog 0.2.1
* Update error messages from invalid sub_graph in model instantiation (#20)
* ci: inherit pypi publish flow (#17)
  * ci: inherit pypi publish flow
  * docs: add to changelog
  * fix: typo in reusable workflow
  * fix: another typo
  * chore: bump actions/setup-python to v5
  * ci: run downstream-ci for changes in src and tests
  * docs: update changelog
  Co-authored-by: Helen Theissen <[email protected]>
* Update CHANGELOG.md to KeepChangelog format
* [pre-commit.ci] pre-commit autoupdate (#25)
  * github.com/psf/black-pre-commit-mirror: 24.4.2 → 24.8.0
  * github.com/astral-sh/ruff-pre-commit: v0.4.6 → v0.6.2
  * github.com/tox-dev/pyproject-fmt: 2.1.3 → 2.2.1
  Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Ci/changelog-release-updater (#26)
  * ci: add changelog release updater
  * docs: update changelog

Co-authored-by: Rilwan (Akanni) Adewoyin <[email protected]>
Co-authored-by: Gert Mertes <[email protected]>
Co-authored-by: Mario Santa Cruz <[email protected]>
Co-authored-by: Jesper Dramsch <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
xfail for MultiHeadSelfAttention
for more information, see https://pre-commit.ci
Force-pushed from a080cc5 to d4940e7.
Codecov Report: All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##           develop      #47     +/-   ##
==========================================
  Coverage    99.84%   99.84%
==========================================
  Files           23       23
  Lines         1277     1304     +27
==========================================
+ Hits          1275     1302     +27
  Misses           2        2
```

☔ View full report in Codecov by Sentry.
Where's the PR template?
….com:ecmwf/anemoi-models into feature/44-make-flash-attention-configurable
Thanks for this contribution. This should make the code much more stable in inference.
I left a few comments about implementation details that caught my eye, and the documentation needs to be updated.
Additionally, now that the attention implementation is configurable, is there a way to also change it in the config? I assume it is set through instantiation; should we add it to the config to make it explicitly available?
Once the errors are implemented correctly, please also add tests that verify they are triggered as expected, so we can catch edge cases.
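For concreteness, a hypothetical sketch of what passing the new options at instantiation might look like; the import path, argument names, and full signature here are assumptions based on the docstrings in this diff, not the actual API:

```python
# Hypothetical usage sketch; import path and full signature are assumptions.
from anemoi.models.layers.attention import MultiHeadSelfAttention

attention = MultiHeadSelfAttention(
    num_heads=16,
    embed_dim=1024,
    # string selecting the underlying attention backend (see docstring below)
    attention_implementation="flash_attention",
    # a value > 0 activates softcapping in the flash-attention backend
    softcap=15.0,
)
```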
```python
self.attention = torch.compile(self.attention)
self.is_attn_compiled = True

# TODO test how this impacts scaling at large model counts
```
Who is this a TODO for?
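For context, a minimal self-contained sketch of the lazy-compile pattern the diff appears to use; the class, method, and the `is_attn_compiled = False` initialisation are assumptions, since only two lines are visible in the diff:

```python
import torch
import torch.nn.functional as F

class LazyCompiledAttention(torch.nn.Module):
    # Hypothetical illustration of the pattern; not the PR's actual class.
    def __init__(self) -> None:
        super().__init__()
        self.attention = F.scaled_dot_product_attention
        self.is_attn_compiled = False

    def forward(self, q, k, v):
        if not self.is_attn_compiled:
            # Compile once and cache the flag so torch.compile is not
            # re-triggered on every forward pass.
            self.attention = torch.compile(self.attention)
            self.is_attn_compiled = True
        return self.attention(q, k, v)
```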
```python
Tensor
    aLiBi slopes
"""
n = 2 ** math.floor(math.log2(num_heads))
```
Since `num_heads` is an integer, we could be using bit-shifting here:

`n = 1 << (num_heads.bit_length() - 1)`

Not sure how necessary speed is here though, as a trade-off against readability. It would definitely need a comment.
Speed is not an issue as it is only calculated once. So, I would go for readability.
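To make the trade-off concrete, here is a small self-contained sketch following the ALiBi paper's reference scheme (the helper name and exact structure are assumptions, not this repository's code) that shows both variants and checks they agree for positive integers:

```python
import math

def get_alibi_slopes(num_heads: int) -> list[float]:
    # Hypothetical helper following the ALiBi paper's reference scheme.
    # Largest power of two <= num_heads; for positive ints the bit-shift
    # form is equivalent to 2 ** math.floor(math.log2(num_heads)).
    n = 1 << (num_heads.bit_length() - 1)
    # Geometric sequence of slopes for the power-of-two part.
    start = 2 ** (-(2 ** -(math.log2(n) - 3)))
    slopes = [start * start**i for i in range(n)]
    if n != num_heads:
        # Interleave slopes from the next power of two for the remainder.
        slopes += get_alibi_slopes(2 * n)[0::2][: num_heads - n]
    return slopes

# The two expressions under discussion agree for all positive integers:
assert all(
    2 ** math.floor(math.log2(h)) == 1 << (h.bit_length() - 1)
    for h in range(1, 2049)
)
```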
```python
    A predefined string which selects which underlying attention
    implementation, by default "flash_attention"
softcap : float, optional
    Anything > 0 activates softcapping flash attention, by default None
```
What does "Anything > 0" mean here? Please adjust this explanation across the docstrings to be more informative to someone who hasn't worked with the attention implementation yet.
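For reference: in implementations that support it (e.g. the softcapping popularised by Gemma 2 and available in recent flash-attention releases), softcap bounds the attention logits with a scaled tanh before the softmax, so any value greater than 0 enables the capping and the value itself sets the bound. A minimal sketch of the idea, not the library's internal code:

```python
import torch

def softcap_logits(logits: torch.Tensor, softcap: float) -> torch.Tensor:
    # Squash raw attention logits into (-softcap, softcap) with tanh;
    # small logits pass through almost unchanged, extreme ones are capped.
    return softcap * torch.tanh(logits / softcap)

# e.g. with softcap=15.0 a logit of 100 becomes ~15, while 0.5 stays ~0.5
print(softcap_logits(torch.tensor([0.5, 100.0]), softcap=15.0))
```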
Force-pushed from 4a99a5e to 2d122df.
Force-pushed from e96cfd1 to d4510f6.
….com:ecmwf/anemoi-models into feature/44-make-flash-attention-configurable
Current setup:
Now:
This PR will be accompanied by changes to the config in Anemoi-training (PR)
Todo:
📚 Documentation preview 📚: https://anemoi-models--47.org.readthedocs.build/en/47/