[ADD] EKFAC #127
base: main
Conversation
Pull Request Test Coverage Report for Build 10975103378
💛 - Coveralls
@f-dangel One thing that is not tested and that could be wrong is the per-example gradient computation when there is weight sharing.
Gave some refactoring comments.

Overall, while reading through the diff, I was wondering if there is a better way to separate the eigenvalue correction of EKFAC. Ideally, I was imagining we could keep KFAC as is and implement EKFAC separately, e.g. by inheriting EKFAC from KFAC. Do you have a good idea how to do this? Otherwise I believe this PR will make the code a lot more complex and, in the long term, complicate extending KFAC, especially for developers who are less familiar with EKFAC.
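To make that concrete, here is a minimal sketch of the separation I have in mind (class and method names are placeholders, not the existing API): EKFAC would reuse KFAC's Kronecker-factor computation unchanged and only add the eigenvalue correction on top.

```python
# Minimal sketch (placeholder names, not the existing API): EKFAC reuses KFAC's
# factor computation as is and only adds the eigenvalue correction afterwards.


class KFAC:
    """Stand-in for the existing KFAC linear operator."""

    def _compute_kronecker_factors(self) -> None:
        ...  # existing KFAC logic, untouched by this PR


class EKFAC(KFAC):
    """EKFAC = KFAC's Kronecker factors plus corrected eigenvalues."""

    def _compute_kronecker_factors(self) -> None:
        super()._compute_kronecker_factors()  # keep KFAC as is
        self._correct_eigenvalues()  # the only EKFAC-specific step

    def _correct_eigenvalues(self) -> None:
        # eigendecompose the factors, then fit per-eigendirection scalings
        ...
```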
# Delete the cached activations
self._cached_activations.clear()
Are these cached activations concatenated over batches? Why don't they have to be cleared inside the data loop?
No, they will just be overwritten; this avoids redundant clearing of the cache before it is filled up again anyway. Do you think it is cleaner to clear the cache explicitly every iteration?
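For illustration, a toy version of the two cache lifetimes (the dict and keys are made up):

```python
# Toy illustration (made-up dict and keys) of the two cache lifetimes discussed above.
cached_activations = {}

for batch_idx in range(3):
    # Per-iteration clearing would make the lifetime explicit ...
    # cached_activations.clear()
    # ... but is redundant, because the same keys are overwritten right here:
    cached_activations["layer1"] = f"activations of batch {batch_idx}"

# Current approach: clear once after the whole data loop.
cached_activations.clear()
```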
"d_out1 d_out2, ... d_out1 d_in1, d_in1 d_in2 -> ... d_out2 d_in2", | ||
) | ||
.square_() | ||
.sum(dim=0) |
Is this sum correct, or do you want to sum out the ... of the einsum result?
Based on the above variable, I would change .sum(dim=0) into .sum(list(range(shared_axes))).
Also check the else branch below for the same suggestions.
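A small standalone example (invented shapes) of why summing only over dim=0 differs from summing over all leading weight-sharing axes:

```python
# Standalone example (invented shapes) contrasting .sum(dim=0) with summing over
# all leading weight-sharing axes of the einsum result.
import torch

shared_axes = 2  # suppose there are two leading weight-sharing dimensions
result = torch.rand(3, 4, 5, 6)  # (*shared, d_out2, d_in2)

only_first = result.square().sum(dim=0)                     # (4, 5, 6): one shared axis remains
all_shared = result.square().sum(list(range(shared_axes)))  # (5, 6): all shared axes summed out
```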
per_example_gradient = einsum(
    g,
    self._cached_activations[module_name],
    "shared d_out, shared d_in -> shared d_out d_in",
shared should be replaced by ...
Then, add a line shared_axes = g.ndim - 2.
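Putting the two suggestions together, here is a self-contained toy version (shapes and names are invented, and I'm assuming the einops-style einsum where the pattern is the last argument, as in the diff). In this toy layout g has exactly one trailing d_out dimension, so the number of leading shared axes comes out as g.ndim - 1; the right offset in the actual code depends on how many trailing non-shared dimensions g carries there.

```python
# Toy sketch (invented shapes/names): use "..." instead of an explicit "shared"
# axis so that any number of leading weight-sharing dimensions is supported.
import torch
from einops import einsum  # named-axis einsum, pattern passed as the last argument

g = torch.rand(8, 5, 4)            # (*shared, d_out): two leading shared axes here
activations = torch.rand(8, 5, 3)  # (*shared, d_in)

per_example_gradient = einsum(
    g,
    activations,
    "... d_out, ... d_in -> ... d_out d_in",
)  # shape (8, 5, 4, 3)

# Number of leading shared axes; in this toy layout g has one trailing dimension.
shared_axes = g.ndim - 1
# Sum over all shared axes instead of only dim 0 (cf. the .sum(dim=0) comment above).
reduced = per_example_gradient.square().sum(list(range(shared_axes)))  # shape (4, 3)
```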
Will continue this PR in ~2 weeks.
Implements EKFAC (and its inverse) support (resolves #116).

I think we should at some point refactor KFACLinearOperator and KFACInverseLinearOperator to inherit from KroneckerProductLinearOperator and EigendecomposedKroneckerProductLinearOperator (or similar) classes, since torch_matmat and other methods can be shared. Also, KFACInverseLinearOperator currently doesn't support trace, det, etc. properties, which could also be shared. I created #126 for this.
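For the record, a very rough sketch of the hierarchy I have in mind (names and method bodies are placeholders, not the current API):

```python
# Rough sketch (placeholder names/bodies, not the current API) of the proposed
# refactoring: a Kronecker-product base class owns torch_matmat and trace/det-style
# properties, and the concrete operators only provide their Kronecker factors.
from torch import Tensor


class KroneckerProductLinearOperator:
    """Shared machinery for operators that are block-diagonal with A ⊗ B blocks."""

    def _get_factors(self, param_name: str) -> tuple[Tensor, Tensor]:
        raise NotImplementedError  # each subclass supplies its own factors

    def torch_matmat(self, M: Tensor) -> Tensor:
        ...  # multiply each A ⊗ B block onto the matching slice of M

    @property
    def trace(self) -> Tensor:
        ...  # tr(A ⊗ B) = tr(A) * tr(B), accumulated over blocks


class KFACLinearOperator(KroneckerProductLinearOperator):
    def _get_factors(self, param_name):
        ...  # return the accumulated input/gradient covariances


class KFACInverseLinearOperator(KroneckerProductLinearOperator):
    def _get_factors(self, param_name):
        ...  # return the (damped) inverses of the KFAC factors


# An EigendecomposedKroneckerProductLinearOperator subclass could additionally hold
# shared eigendecomposition handling, which EKFAC and its inverse would build on.
```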