
Move buffers to device #10523

Status: Open (wants to merge 2 commits into main)


@hlky (Collaborator) commented Jan 10, 2025

What does this PR do?

```
  File "diffusers/pipelines/sana/pipeline_sana.py", line 882, in __call__
    noise_pred = self.transformer(
  File "torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "diffusers/models/transformers/sana_transformer.py", line 414, in forward
    hidden_states = self.patch_embed(hidden_states)
  File "torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "diffusers/models/embeddings.py", line 569, in forward
    return (latent + pos_embed).to(latent.dtype)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```

We hit this path using `pos_embed` as computed during the module's `__init__`. `pos_embed` is registered as a buffer, and it appears that buffers are not moved to the device when quantization is used; normally, buffers are moved either by `.to()` or by the offload hooks.

```python
pos_embed = self.pos_embed
```

```python
self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=persistent)
```
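Some background on why such a buffer can be missed: PyTorch excludes a buffer registered with `persistent=False` from `state_dict()`, so a loader that only walks the checkpoint's state dict never places it on a device. The sketch below uses a hypothetical `TinyPatchEmbed` module (an illustration, not the diffusers `PatchEmbed`) that registers its positional table the same way as the excerpt above.

```python
import torch
from torch import nn


class TinyPatchEmbed(nn.Module):
    """Hypothetical stand-in for PatchEmbed: a projection plus a fixed positional table."""

    def __init__(self, num_patches: int = 4, dim: int = 8, persistent: bool = False):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        pos_embed = torch.randn(num_patches, dim)
        # Same registration pattern as the excerpt: a buffer with a `persistent` flag.
        self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=persistent)

    def forward(self, latent: torch.Tensor) -> torch.Tensor:
        # Raises a device-mismatch error if `pos_embed` is not on the same device as `latent`.
        return (self.proj(latent) + self.pos_embed).to(latent.dtype)


non_persistent = TinyPatchEmbed(persistent=False)
persistent = TinyPatchEmbed(persistent=True)

# Every buffer appears in named_buffers(), but only persistent ones are saved and loaded.
print([name for name, _ in non_persistent.named_buffers()])  # ['pos_embed']
print(sorted(non_persistent.state_dict()))                   # ['proj.bias', 'proj.weight']
print(sorted(persistent.state_dict()))                       # ['pos_embed', 'proj.bias', 'proj.weight']
```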

The SANA 4K example now runs; however, I got an OOM during decoding, which needs the DC-AE tiling PR to be merged.

Code

```python
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaPipeline, SanaTransformer2DModel
from transformers import AutoModel, BitsAndBytesConfig

model_id = "Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers"

# 8-bit text encoder, quantized through the transformers bitsandbytes integration.
text_encoder_8bit = AutoModel.from_pretrained(
    model_id,
    subfolder="text_encoder",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.bfloat16,
)

# 8-bit transformer, quantized through the diffusers bitsandbytes integration.
transformer_8bit = SanaTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=DiffusersBitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.bfloat16,
)

pipeline = SanaPipeline.from_pretrained(
    model_id,
    text_encoder=text_encoder_8bit,
    transformer=transformer_8bit,
    torch_dtype=torch.bfloat16,
    device_map="balanced",
)

prompt = "a tiny astronaut hatching from an egg on the moon"
# Without this PR, this call fails with the device-mismatch RuntimeError shown above.
image = pipeline(prompt).images[0]
image.save("sana.png")
```

This issue likely affects any model that uses `PatchEmbed` with quantization, as well as any other module that uses buffers.
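To check whether a particular model is affected, one option is to compare each buffer's device with the device of the parameters in the module that owns it. `find_misplaced_buffers` below is a hypothetical diagnostic written for illustration, not a diffusers or PyTorch API; it assumes that a module's parameters share one device and uses the `meta` device to simulate the mismatch without a GPU.

```python
import torch
from torch import nn


def find_misplaced_buffers(model: nn.Module) -> list[tuple[str, torch.device, torch.device]]:
    """Hypothetical diagnostic: report buffers that sit on a different device than the
    parameters of the module that owns them."""
    misplaced = []
    for name, buf in model.named_buffers():
        owner_path, _, _ = name.rpartition(".")
        owner = model.get_submodule(owner_path)  # an empty path returns the model itself
        param = next(owner.parameters(), None)
        if param is None:
            continue  # no parameters to compare against
        if buf.device != param.device:
            misplaced.append((name, buf.device, param.device))
    return misplaced


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.register_buffer("pos_embed", torch.zeros(1, 4, 8), persistent=False)


model = nn.Sequential(Block())
print(find_misplaced_buffers(model))  # []

# Simulate the bug: move only the parameters and leave the buffer behind on the CPU.
model[0].proj.to("meta")
print(find_misplaced_buffers(model))  # one entry for '0.pos_embed': buffer on cpu, parameters on meta
```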

Fixes #10520

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@sayakpaul

Comment on lines +249 to +259
```python
# Place every buffer on `device`, including non-persistent buffers that are absent from `state_dict`.
for param_name, param in model.named_buffers():
    if is_quantized and (
        hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
    ):
        hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
    else:
        if accepts_dtype:
            set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
        else:
            set_module_tensor_to_device(model, param_name, device, value=param)
```
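Without the quantization branch and the `set_module_kwargs` dtype handling, the buffer loop in the excerpt above amounts to re-assigning every buffer on the target device. A minimal sketch in plain PyTorch, with a hypothetical `move_buffers_to_device` helper standing in for accelerate's `set_module_tensor_to_device`; it uses the `meta` device as the target so it runs without a GPU.

```python
from __future__ import annotations

import torch
from torch import nn


def move_buffers_to_device(model: nn.Module, device: torch.device | str) -> None:
    """Hypothetical helper: re-assign every buffer (persistent or not) on `device`."""
    # Materialize the list first, because the loop mutates the modules it walks.
    for name, buf in list(model.named_buffers()):
        owner_path, _, buffer_name = name.rpartition(".")
        owner = model.get_submodule(owner_path)
        # Assigning a tensor to an existing buffer name replaces the buffer and
        # keeps its persistent/non-persistent status.
        setattr(owner, buffer_name, buf.to(device))


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.register_buffer("pos_embed", torch.zeros(1, 4, 8), persistent=False)


model = nn.Sequential(Block())
move_buffers_to_device(model, "meta")
print(model[0].pos_embed.device)    # meta
print(model[0].proj.weight.device)  # cpu
```

Only buffers are touched here; in the excerpt, parameters are placed separately by the state-dict loop that precedes it.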

Member

Interesting. Let's also have a test for this?

Cc: @SunMarc too.

Member

Persistent buffers should already be in the state dict, no? Is this for non-persistent buffers?

Collaborator Author

Yes, for non-persistent buffers.


@FurkanGozukara commented Jan 11, 2025

When can we expect a merge? Thank you. Sample code here: #10520 (comment)
