Flux - torchao inference not working #10470

Open
nitinmukesh opened this issue Jan 6, 2025 · 11 comments
Labels
bug Something isn't working

Comments

@nitinmukesh

nitinmukesh commented Jan 6, 2025

Describe the bug

  1. Flux with torchao int8wo not working
  2. enable_sequential_cpu_offload not working

Reproduction

Example taken from #10009 (merged):

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

model_id = "black-forest-labs/Flux.1-Dev"
dtype = torch.bfloat16

quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=dtype,
)
pipe = FluxPipeline.from_pretrained(
    model_id,
    transformer=transformer,
    torch_dtype=dtype,
)
# pipe.to("cuda")

# pipe.enable_sequential_cpu_offload()
pipe.vae.enable_tiling()

prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
image.save("output.png")

Logs

Without CPU offload, the script hangs at this point:

(venv) C:\ai1\diffuser_t2i>python FLUX_torchao.py
Fetching 3 files: 100%|█████████████████████████████████████████████████████| 3/3 [00:00<?, ?it/s]
Loading pipeline components...:  29%|████████▊                      | 2/7 [00:00<00:00,  5.36it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading checkpoint shards: 100%|████████████████████████████████████| 2/2 [00:00<00:00,  6.86it/s]
Loading pipeline components...: 100%|███████████████████████████████| 7/7 [00:02<00:00,  2.38it/s]

With CPU offload, it fails with a traceback:

(venv) C:\ai1\diffuser_t2i>python FLUX_torchao.py
Fetching 3 files: 100%|█████████████████████████████████████████████████████| 3/3 [00:00<?, ?it/s]
Loading checkpoint shards: 100%|████████████████████████████████████| 2/2 [00:00<00:00,  6.98it/s]
Loading pipeline components...:  29%|████████▊                      | 2/7 [00:00<00:01,  2.62it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading pipeline components...: 100%|███████████████████████████████| 7/7 [00:01<00:00,  4.31it/s]
Traceback (most recent call last):
  File "C:\ai1\diffuser_t2i\FLUX_torchao.py", line 21, in <module>
    pipe.enable_sequential_cpu_offload()
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\diffusers\pipelines\pipeline_utils.py", line 1179, in enable_sequential_cpu_offload
    cpu_offload(model, device, offload_buffers=offload_buffers)
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\accelerate\big_modeling.py", line 205, in cpu_offload
    attach_align_device_hook(
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\accelerate\hooks.py", line 518, in attach_align_device_hook
    attach_align_device_hook(
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\accelerate\hooks.py", line 518, in attach_align_device_hook
    attach_align_device_hook(
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\accelerate\hooks.py", line 518, in attach_align_device_hook
    attach_align_device_hook(
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\accelerate\hooks.py", line 509, in attach_align_device_hook
    add_hook_to_module(module, hook, append=True)
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\accelerate\hooks.py", line 161, in add_hook_to_module
    module = hook.init_hook(module)
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\accelerate\hooks.py", line 308, in init_hook
    set_module_tensor_to_device(module, name, "meta")
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\accelerate\utils\modeling.py", line 355, in set_module_tensor_to_device
    new_value.layout_tensor,
AttributeError: 'AffineQuantizedTensor' object has no attribute 'layout_tensor'

System Info

Windows 11

(venv) C:\ai1\diffuser_t2i>python --version
Python 3.10.11

(venv) C:\ai1\diffuser_t2i>echo %CUDA_PATH%
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.6

(venv) C:\ai1\diffuser_t2i>pip list
Package            Version
------------------ ------------
accelerate         1.2.0.dev0
aiofiles           23.2.1
annotated-types    0.7.0
anyio              4.7.0
bitsandbytes       0.45.0
certifi            2024.12.14
charset-normalizer 3.4.1
click              8.1.8
colorama           0.4.6
diffusers          0.33.0.dev0
einops             0.8.0
exceptiongroup     1.2.2
fastapi            0.115.6
ffmpy              0.5.0
filelock           3.16.1
fsspec             2024.12.0
gguf               0.13.0
gradio             5.9.1
gradio_client      1.5.2
h11                0.14.0
httpcore           1.0.7
httpx              0.28.1
huggingface-hub    0.25.2
idna               3.10
imageio            2.36.1
imageio-ffmpeg     0.5.1
importlib_metadata 8.5.0
Jinja2             3.1.5
markdown-it-py     3.0.0
MarkupSafe         2.1.5
mdurl              0.1.2
mpmath             1.3.0
networkx           3.4.2
ninja              1.11.1.3
numpy              2.2.1
opencv-python      4.10.0.84
orjson             3.10.13
packaging          24.2
pandas             2.2.3
pillow             11.1.0
pip                23.0.1
protobuf           5.29.2
psutil             6.1.1
pydantic           2.10.4
pydantic_core      2.27.2
pydub              0.25.1
Pygments           2.18.0
python-dateutil    2.9.0.post0
python-multipart   0.0.20
pytz               2024.2
PyYAML             6.0.2
regex              2024.11.6
requests           2.32.3
rich               13.9.4
ruff               0.8.6
safehttpx          0.1.6
safetensors        0.5.0
semantic-version   2.10.0
sentencepiece      0.2.0
setuptools         65.5.0
shellingham        1.5.4
six                1.17.0
sniffio            1.3.1
starlette          0.41.3
sympy              1.13.1
tokenizers         0.21.0
tomlkit            0.13.2
torch              2.5.1+cu124
torchao            0.7.0
torchvision        0.20.1+cu124
tqdm               4.67.1
transformers       4.47.1
typer              0.15.1
typing_extensions  4.12.2
tzdata             2024.2
urllib3            2.3.0
uvicorn            0.34.0
websockets         14.1
wheel              0.45.1
zipp               3.21.0

Who can help?

No response

nitinmukesh added the bug label on Jan 6, 2025
@a-r-r-o-w
Member

Thanks for reporting! Since layout_tensor was made an internal private attribute in TorchAO in version 0.7.0, it seems like we need to update how we handle it in accelerate (which is what's used for sequential offloading). I'll open a fix soon
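
For readers hitting the same traceback, the incompatibility described above can be sketched as a version-tolerant attribute lookup. This is only an illustration, not the actual accelerate patch. It assumes that torchao 0.7.0 exposes the inner tensor under the name `tensor_impl` (an assumption not stated in this thread, which only says `layout_tensor` became private), and the helper name `get_inner_tensor` is hypothetical.

```python
def get_inner_tensor(quantized):
    """Return the packed inner tensor of a torchao quantized tensor.

    Tries the newer attribute name first, then the pre-0.7.0 one.
    `tensor_impl` is an assumed name; `layout_tensor` comes from the traceback.
    """
    for name in ("tensor_impl", "layout_tensor"):
        if hasattr(quantized, name):
            return getattr(quantized, name)
    raise AttributeError(
        f"{type(quantized).__name__} exposes neither 'tensor_impl' nor 'layout_tensor'"
    )


# Stand-ins for the two attribute layouts, so the sketch runs without torchao installed.
class OldStyleTensor:
    layout_tensor = "packed-int8 (old attribute)"


class NewStyleTensor:
    tensor_impl = "packed-int8 (new attribute)"


print(get_inner_tensor(OldStyleTensor()))
print(get_inner_tensor(NewStyleTensor()))
```

Code that hard-codes `new_value.layout_tensor`, as in the accelerate frame at the bottom of the traceback, would fail on the second stand-in in the same way the report shows.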

@nitinmukesh
Author

@a-r-r-o-w

Thank you for looking into this.
The inference code also does not work without enable_sequential_cpu_offload. It just hangs at "Loading pipeline components".

I also tried saving the quantized model locally:

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

sharded_model_id = "black-forest-labs/Flux.1-Dev"
single_model_path = "single_model"
dtype = torch.bfloat16
quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
    sharded_model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=dtype,
)
transformer.save_pretrained(single_model_path, max_shard_size="100GB")

This raises an error:

(venv) C:\ai1\diffuser_t2i>python FLUX_int8.py
Fetching 3 files: 100%|█████████████████████████████████████████████████████| 3/3 [00:00<?, ?it/s]
C:\ai1\diffuser_t2i\venv\lib\site-packages\torchao\utils.py:434: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  return func(*args, **kwargs)
Traceback (most recent call last):
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\safetensors\torch.py", line 13, in storage_ptr
    return tensor.untyped_storage().data_ptr()
RuntimeError: Attempted to access the data pointer on an invalid python storage.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "C:\ai1\diffuser_t2i\FLUX_int8.py", line 14, in <module>
    transformer.save_pretrained(single_model_path, max_shard_size="100GB")
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\diffusers\models\modeling_utils.py", line 434, in save_pretrained
    safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\safetensors\torch.py", line 286, in save_file
    serialize_file(_flatten(tensors), filename, metadata=metadata)
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\safetensors\torch.py", line 481, in _flatten
    shared_pointers = _find_shared_tensors(tensors)
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\safetensors\torch.py", line 72, in _find_shared_tensors
    if v.device != torch.device("meta") and storage_ptr(v) != 0 and storage_size(v) != 0:
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\safetensors\torch.py", line 17, in storage_ptr
    return tensor.storage().data_ptr()
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\torch\storage.py", line 1224, in data_ptr
    return self._data_ptr()
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\torch\storage.py", line 1228, in _data_ptr
    return self._untyped_storage.data_ptr()
RuntimeError: Attempted to access the data pointer on an invalid python storage.

Posting here as both issues are related to the same model and quantization.

@zhangvia

zhangvia commented Jan 8, 2025

same error here

I tried loading the Flux pipeline on the CPU and quantizing the transformer and text_encoder_2, but then an error occurs in pipe.to('cuda'). Because my GPU only has 24 GB of VRAM, I can't load the pipeline onto the GPU and then quantize the model. So I saved the quantized pipeline instead, and got the same error as above. I am using torchao 0.7.0 and diffusers from the main branch.

@a-r-r-o-w
Member

@nitinmukesh For this comment, you need to pass safe_serialization=False, I think, for it to work. This is because there is no safetensors support for saving torchao quantized models yet.

Regarding the sequential offloading issue, I'm opening a PR to accelerate shortly.

@a-r-r-o-w
Member

@nitinmukesh Could you try installing accelerate from this branch and seeing if it fixes the inference? It's working for me now

@zhangvia

zhangvia commented Jan 9, 2025

> @nitinmukesh For this comment, you need to pass safe_serialization=False, I think, for it to work. This is because there is no safetensors support for saving torchao quantized models yet.

I don't think safe_serialization=False helps (screenshot attached).

@a-r-r-o-w
Member

a-r-r-o-w commented Jan 9, 2025

@zhangvia Cannot seem to replicate. Could you share the output of diffusers-cli env and try to upgrade your huggingface_hub version? For example, this is mine:

Env
- 🤗 Diffusers version: 0.33.0.dev0
- Platform: Linux-5.4.0-166-generic-x86_64-with-glibc2.31
- Running on Google Colab?: No
- Python version: 3.10.14
- PyTorch version (GPU?): 2.5.1+cu124 (True)
- Flax version (CPU?/GPU?/TPU?): 0.8.5 (cpu)
- Jax version: 0.4.31
- JaxLib version: 0.4.31
- Huggingface_hub version: 0.26.2
- Transformers version: 4.48.0.dev0
- Accelerate version: 1.1.0.dev0
- PEFT version: 0.13.3.dev0
- Bitsandbytes version: 0.43.3
- Safetensors version: 0.4.5
- xFormers version: not installed
- Accelerator: NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA DGX Display, 4096 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

model_id = "black-forest-labs/Flux.1-Dev"
dtype = torch.bfloat16

quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=dtype,
    cache_dir="/raid/.cache/huggingface",
)
transformer.save_pretrained("/raid/aryan/flux-transformer-int8wo", max_shard_size="100GB", safe_serialization=False)

This is the code I'm using for testing. LMK if this should be different
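
Since the suggestion above is to compare environments and upgrade huggingface_hub, a quick way to print just the relevant package versions is sketched below. It uses only the standard library. The helper name `installed_versions` and the choice of packages are illustrative assumptions, and it is not a replacement for `diffusers-cli env`.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each package name to its installed version, or None if it is missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Packages that appear in the tracebacks and environment dumps in this thread.
PACKAGES = ["diffusers", "accelerate", "torchao", "huggingface_hub", "safetensors", "torch"]

for name, version in installed_versions(PACKAGES).items():
    print(f"{name:16} {version or 'not installed'}")
```

Running this in both environments makes differences such as the huggingface_hub version easy to spot side by side.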

@nitinmukesh
Author

Thank you @a-r-r-o-w

I will verify both issues and let you know.

@zhangvia

@a-r-r-o-w
Here is my diffusers-cli env output:

2025-01-10 01:45:20.889200: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-01-10 01:45:20.902103: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1736473520.920618    5981 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1736473520.925624    5981 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-01-10 01:45:20.942390: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.

Copy-and-paste the text below in your GitHub issue and FILL OUT the two last points.

- 🤗 Diffusers version: 0.33.0.dev0
- Platform: Linux-5.2.14-050214-generic-x86_64-with-glibc2.31
- Running on Google Colab?: No
- Python version: 3.10.0
- PyTorch version (GPU?): 2.5.1+cu124 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.24.3
- Transformers version: 4.46.3
- Accelerate version: 1.2.1
- PEFT version: 0.13.2
- Bitsandbytes version: 0.44.1
- Safetensors version: 0.4.5
- xFormers version: 0.0.28.post3
- Accelerator: NVIDIA GeForce RTX 4090, 24564 MiB
NVIDIA GeForce RTX 4090, 24564 MiB
NVIDIA GeForce RTX 4090, 24564 MiB
NVIDIA GeForce RTX 4090, 24564 MiB
NVIDIA GeForce RTX 4090, 24564 MiB
NVIDIA GeForce RTX 4090, 24564 MiB
NVIDIA GeForce RTX 4090, 24564 MiB
NVIDIA GeForce RTX 4090, 24564 MiB
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>

I used the exact script you pasted, but changed the model to the FLUX.1-Fill-dev checkpoint:

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

model_id = "/media/74nvme/research/checkpoints/FLUX.1-Fill-dev/"
dtype = torch.bfloat16

quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=dtype,
)
transformer.save_pretrained("/media/74nvme/checkpoints/flux/flux-fp8/transformer", max_shard_size="100GB", safe_serialization=False)

I get the same error:

2025-01-10 01:43:12.589355: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-01-10 01:43:12.835107: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1736473392.952239    5850 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1736473392.978978    5850 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-01-10 01:43:13.157551: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
/media/74nvme/software/miniconda3/envs/comfyui/lib/python3.10/site-packages/torchao/utils.py:434: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  return func(*args, **kwargs)
Traceback (most recent call last):
  File "/media/74nvme/software/miniconda3/envs/comfyui/lib/python3.10/site-packages/huggingface_hub/serialization/_torch.py", line 406, in storage_ptr
    return tensor.untyped_storage().data_ptr()
RuntimeError: Attempted to access the data pointer on an invalid python storage.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/media/74nvme/research/test.py", line 269, in <module>
    transformer.save_pretrained("/media/74nvme/checkpoints/flux/flux-znwtryon-fp8/transformer", max_shard_size="100GB", safe_serialization=False)
  File "/media/74nvme/software/miniconda3/envs/comfyui/lib/python3.10/site-packages/diffusers/models/modeling_utils.py", line 406, in save_pretrained
    state_dict_split = split_torch_state_dict_into_shards(
  File "/media/74nvme/software/miniconda3/envs/comfyui/lib/python3.10/site-packages/huggingface_hub/serialization/_torch.py", line 330, in split_torch_state_dict_into_shards
    return split_state_dict_into_shards_factory(
  File "/media/74nvme/software/miniconda3/envs/comfyui/lib/python3.10/site-packages/huggingface_hub/serialization/_base.py", line 108, in split_state_dict_into_shards_factory
    storage_id = get_storage_id(tensor)
  File "/media/74nvme/software/miniconda3/envs/comfyui/lib/python3.10/site-packages/huggingface_hub/serialization/_torch.py", line 359, in get_torch_storage_id
    unique_id = storage_ptr(tensor)
  File "/media/74nvme/software/miniconda3/envs/comfyui/lib/python3.10/site-packages/huggingface_hub/serialization/_torch.py", line 410, in storage_ptr
    return tensor.storage().data_ptr()
  File "/media/74nvme/software/miniconda3/envs/comfyui/lib/python3.10/site-packages/torch/storage.py", line 1224, in data_ptr
    return self._data_ptr()
  File "/media/74nvme/software/miniconda3/envs/comfyui/lib/python3.10/site-packages/torch/storage.py", line 1228, in _data_ptr
    return self._untyped_storage.data_ptr()
RuntimeError: Attempted to access the data pointer on an invalid python storage.

@nitinmukesh
Author

nitinmukesh commented Jan 10, 2025

@a-r-r-o-w

I have verified inference and it is working now. Next I will test saving the quantized model.
I tried the two offload options below one at a time, and each works fine:
pipe.enable_sequential_cpu_offload()

OR

pipe.enable_model_cpu_offload()

(venv) C:\ai1\diffuser_t2i>python FLUX_FP8_int8AO.py
Fetching 3 files: 100%|███████████████████████████████████████| 3/3 [00:00<?, ?it/s]
Loading checkpoint shards: 100%|█████████████████████████████| 2/2 [00:00<00:00,  9.12it/s]
Loading pipeline components...:  57%|█████████████████▋             | 4/7 [00:00<00:00,  5.09it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading pipeline components...: 100%|█████████████████████████| 7/7 [00:01<00:00,  6.53it/s]
100%|█████████████████████████████████████████████| 5/5 [01:30<00:00, 18.19s/it]

Unfortunately, even with different combinations of num_inference_steps and guidance_scale, the output quality is very poor. I am not sure whether this is due to quantization or something else. I verified with Forge (without quantization), and the output with the same settings is good.

[Attached sample outputs: output0, output2, output63]
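
One thing that may be worth ruling out is the sampling settings themselves. The reproduction script at the top uses num_inference_steps=4 and guidance_scale=0.0, which match the values commonly documented for the schnell checkpoint rather than the dev one. The sketch below only illustrates that distinction. The preset values are assumptions taken from the usual model-card examples, not from this thread, and `sampling_kwargs` is a hypothetical helper. It may not explain the quality problem, since other combinations were reportedly tried.

```python
# Settings as commonly documented for each checkpoint (assumed, not taken from this thread).
PRESETS = {
    "schnell": {"num_inference_steps": 4, "guidance_scale": 0.0},
    "dev": {"num_inference_steps": 50, "guidance_scale": 3.5},
}


def sampling_kwargs(model_id):
    """Pick sampling settings from the checkpoint name (hypothetical helper)."""
    variant = "schnell" if "schnell" in model_id.lower() else "dev"
    # Return a copy so callers can tweak values without mutating the presets.
    return dict(PRESETS[variant])


print(sampling_kwargs("black-forest-labs/Flux.1-Dev"))
```

The result can be passed straight to the pipeline call, for example `pipe(prompt, **sampling_kwargs(model_id))`.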

@nitinmukesh
Author

nitinmukesh commented Jan 10, 2025

Test 2: Finished testing saving the quantized model with torchao. It was successful.

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

model_id = "black-forest-labs/Flux.1-Dev"
dtype = torch.bfloat16

quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=dtype,
)
transformer.save_pretrained("int8", max_shard_size="100GB", safe_serialization=False)

The next test is to load the quantized model and use it directly. Is there any sample code showing how to load a quantized model?
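
A minimal, untested sketch of one way to reload the checkpoint saved above follows. It assumes that, because the model was written with safe_serialization=False, it has to be read back with use_safetensors=False, and that "int8" is the directory used in the save script. The helper names `load_kwargs` and `load_quantized_transformer` are hypothetical.

```python
def load_kwargs(dtype):
    """Keyword arguments for reloading a checkpoint saved with safe_serialization=False."""
    # use_safetensors=False asks from_pretrained to read the pickle-based weights file.
    return {"torch_dtype": dtype, "use_safetensors": False}


def load_quantized_transformer(path="int8"):
    """Reload the quantized transformer (imports deferred: needs torch and diffusers)."""
    import torch
    from diffusers import FluxTransformer2DModel

    return FluxTransformer2DModel.from_pretrained(path, **load_kwargs(torch.bfloat16))
```

The returned transformer could then be passed to FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=dtype), in the same way as in the reproduction script at the top of this issue.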
