Expose mixed_precision dtype arguments #348
Conversation
Thanks for adding this.
torchtitan/config_manager.py
Outdated
TORCH_DTYPE_ARGS = [
    "checkpoint.export_dtype",
    "training.mixed_precision_param",
    "training.mixed_precision_reduce",
]
nit: should "reduce" be "grad"?
no, I'm following the existing naming in the mixed_precision config struct
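For context, the naming mirrors FSDP2's mixed-precision policy, whose two dtype knobs are `param_dtype` and `reduce_dtype`. A minimal sketch, assuming PyTorch's composable FSDP API of that era (import path and signatures may differ across versions):

import torch
from torch.distributed._composable.fsdp import MixedPrecisionPolicy

# "param" controls the dtype parameters are cast to for compute;
# "reduce" controls the dtype used for gradient reductions
# (reduce-scatter / all-reduce).
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
)
# fully_shard(model, mp_policy=mp_policy)  # applied per-module in torchtitan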
torch dtype to use for reductions when applying mixed precision via FSDP.
This feature only takes effect when data_parallel_degree > 1
nit: "reductions" -> "gradients"
torchtitan/config_manager.py
Outdated
for k_, v_ in v.items():
    if ".".join([k, k_]) in TORCH_DTYPE_ARGS:
        v[k_] = torch_dtype(v_)
nit: comment please?
lgtm
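For reference, the string-to-dtype util the commit message mentions could look like the sketch below. The helper name `torch_dtype` comes from the snippet above, but its actual body in torchtitan may differ:

import torch

def torch_dtype(name: str) -> torch.dtype:
    # Map a config string like "bfloat16" to the matching torch.dtype.
    dtype = getattr(torch, name, None)
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"invalid torch dtype name: {name!r}")
    return dtype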
add training.mixed_precision_param and .mixed_precision_reduce options

refactor a util to map strings to torch dtypes

ghstack-source-id: 387e1ca13ad23e859d21d7760f858ee6e269a796
Pull Request resolved: #348
Stack from ghstack (oldest at bottom):
add training.mixed_precision_param and .mixed_precision_reduce options
refactor a util to map strings to torch dtypes
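Putting both pieces together, the options can be written as plain strings in the config and converted after parsing. A hypothetical end-to-end illustration (the dict layout mimics parsed TOML sections; torchtitan's real plumbing may differ):

import torch

# Parsed config sections, as they might look after TOML/CLI parsing.
cfg = {
    "training": {
        "mixed_precision_param": "bfloat16",
        "mixed_precision_reduce": "float32",
    },
    "checkpoint": {"export_dtype": "float32"},
}

TORCH_DTYPE_ARGS = [
    "checkpoint.export_dtype",
    "training.mixed_precision_param",
    "training.mixed_precision_reduce",
]

# Convert string-valued dtype fields in place (getattr stands in for the
# torch_dtype helper sketched earlier).
for k, v in cfg.items():
    for k_, v_ in v.items():
        if ".".join([k, k_]) in TORCH_DTYPE_ARGS:
            v[k_] = getattr(torch, v_)

assert cfg["training"]["mixed_precision_param"] is torch.bfloat16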