Move buffers to device #10523
base: main
Conversation
for param_name, param in model.named_buffers():
    # Buffers go through the same quantization check as parameters: a buffer
    # may itself need to be materialized as a quantized tensor on the device.
    if is_quantized and (
        hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
    ):
        hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
    else:
        # Otherwise, move the buffer to the target device, passing the dtype
        # kwargs only when the helper supports them.
        if accepts_dtype:
            set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
        else:
            set_module_tensor_to_device(model, param_name, device, value=param)
Interesting. Let's also have a test for this?
Cc: @SunMarc too.
Persistent buffers should be in the state dict, no? So this is for non-persistent buffers?
Yes, for non-persistent buffers.
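For illustration, here is the standard PyTorch behavior being discussed (a minimal sketch, not code from this PR):

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        # Persistent buffer: serialized into (and loaded from) the state dict.
        self.register_buffer("persistent_buf", torch.zeros(4))
        # Non-persistent buffer: recomputed in __init__, absent from the state dict.
        self.register_buffer("transient_buf", torch.arange(4.0), persistent=False)

m = Toy()
print("persistent_buf" in m.state_dict())  # True
print("transient_buf" in m.state_dict())   # False
# model.to(device) moves both kinds of buffers; a loading path that only
# walks state-dict keys misses the non-persistent one.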
When can we expect a merge? Thank you. Sample code here: #10520 (comment)
What does this PR do?
We are hitting this path with pos_embed as computed during the module's __init__. pos_embed is registered as a buffer, and it appears that buffers aren't getting moved to device when we're using quantization; normally, buffers are moved by to() or by offload hooks.

diffusers/src/diffusers/models/embeddings.py, line 567 in 52c05bd
diffusers/src/diffusers/models/embeddings.py, line 513 in 52c05bd
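For context, the pattern at those lines looks roughly like this (a simplified sketch of the PatchEmbed-style buffer registration, with made-up shapes and a stand-in for the sincos computation, not the exact diffusers source):

import torch
import torch.nn as nn

class PatchEmbedSketch(nn.Module):
    def __init__(self, num_patches=16, embed_dim=8):
        super().__init__()
        # Computed in __init__, so it is not part of the checkpoint; registered
        # as a non-persistent buffer, so it never appears in the state dict.
        pos_embed = torch.randn(1, num_patches, embed_dim)  # stand-in for get_2d_sincos_pos_embed
        self.register_buffer("pos_embed", pos_embed, persistent=False)

Because nothing in the state dict corresponds to pos_embed, a loading path that only walks state-dict keys leaves it on the CPU.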
The SANA 4K example now runs; however, I got an OOM on decode, which needs the DC-AE tiling PR to merge first.
Code
This issue likely affects any model using PatchEmbed with quantization, as well as any other modules using buffers.

Fixes #10520
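Along the lines of the requested test, a minimal check could look like this (a sketch with an assumed helper name; the actual test added to the PR may differ):

import torch

def assert_buffers_on(model: torch.nn.Module, device: torch.device):
    # Hypothetical helper: after loading a quantized model, every buffer
    # (persistent or not) should live on the target device.
    for name, buf in model.named_buffers():
        assert buf.device == device, f"buffer {name} is on {buf.device}, expected {device}"

# e.g. assert_buffers_on(quantized_model, torch.device("cuda:0"))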
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@sayakpaul