implementing flux on TPUs with ptxla #10515

Merged: 13 commits merged on Jan 16, 2025
100 changes: 100 additions & 0 deletions examples/research_projects/pytorch_xla/inference/flux/README.md
@@ -0,0 +1,100 @@
# Generating images using Flux and PyTorch/XLA

The `flux_inference` script shows how to generate images with Flux on TPU devices using PyTorch/XLA. It uses the pallas flash attention kernel for faster generation.

It has been tested on [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPUs. No other TPU types have been tested.

## Create TPU

To create a TPU on Google Cloud, follow [this guide](https://cloud.google.com/tpu/docs/v6e).

## Setup TPU environment

SSH into the VM and install PyTorch and PyTorch/XLA:

```bash
pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
```

Verify that PyTorch and PyTorch/XLA were installed correctly:

```bash
python3 -c "import torch; import torch_xla;"
```
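
Optionally, confirm that a TPU-backed XLA device is visible (a minimal check; assumes the TPU runtime installed above is available):

```python
import torch_xla.core.xla_model as xm

# Should print an XLA device such as "xla:0" when the TPU runtime is set up correctly.
print(xm.xla_device())
```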

Install the dependencies:

```bash
pip install transformers accelerate sentencepiece structlog
pushd ../../..
pip install .
popd
```

## Run the inference job

### Authenticate

Run the following command to authenticate with your Hugging Face token so the Flux weights can be downloaded.

```bash
huggingface-cli login
```
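
Alternatively, you can authenticate from Python (a minimal sketch using `huggingface_hub.login`, which is pulled in as a dependency of the packages installed above; the token string is a placeholder):

```python
from huggingface_hub import login

# Replace the placeholder with your own Hugging Face access token.
login(token="hf_your_token_here")
```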

Then run:

```bash
python flux_inference.py
```

The script loads the text encoders onto the CPU and the Flux transformer and VAE models onto the TPU. The first run is slower because XLA compiles the model and stores the compiled programs in a cache; on subsequent runs compilation is much faster, and the later inference passes are the fastest.
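
For reference, here is a condensed sketch of that split, drawn from the full `flux_inference.py` shown below:

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

from diffusers import FluxPipeline

ckpt_id = "black-forest-labs/FLUX.1-dev"

# Persistent compilation cache: compiled XLA programs are reused on later runs.
xr.initialize_cache("/tmp/data/compiler_cache_tRiLlium_eXp", readonly=False)

# Text encoders and tokenizers stay on the CPU ...
text_pipe = FluxPipeline.from_pretrained(
    ckpt_id, transformer=None, vae=None, torch_dtype=torch.bfloat16
).to("cpu")

# ... while the transformer and VAE run on the TPU with the pallas flash attention kernel.
flux_pipe = FluxPipeline.from_pretrained(
    ckpt_id, text_encoder=None, tokenizer=None, text_encoder_2=None, tokenizer_2=None, torch_dtype=torch.bfloat16
).to(xm.xla_device())
flux_pipe.transformer.enable_xla_flash_attention(partition_spec=("data", None, None, None), is_flux=True)
```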

On a Trillium v6e-4, you should expect ~9 sec / 4 images or 2.25 sec / image (as devices run generation in parallel):

```bash
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
Loading checkpoint shards: 100%|███████████████████████████████| 2/2 [00:00<00:00, 7.01it/s]
Loading pipeline components...: 40%|██████████▍ | 2/5 [00:00<00:00, 3.78it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading pipeline components...: 100%|██████████████████████████| 5/5 [00:00<00:00, 6.72it/s]
2025-01-10 00:51:25 [info ] loading flux from black-forest-labs/FLUX.1-dev
2025-01-10 00:51:25 [info ] loading flux from black-forest-labs/FLUX.1-dev
2025-01-10 00:51:26 [info ] loading flux from black-forest-labs/FLUX.1-dev
2025-01-10 00:51:26 [info ] loading flux from black-forest-labs/FLUX.1-dev
Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 4.29it/s]
Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 3.26it/s]
Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 3.27it/s]
Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 3.25it/s]
2025-01-10 00:51:34 [info ] starting compilation run...
2025-01-10 00:51:35 [info ] starting compilation run...
2025-01-10 00:51:37 [info ] starting compilation run...
2025-01-10 00:51:37 [info ] starting compilation run...
2025-01-10 00:52:52 [info ] compilation took 78.5155531649998 sec.
2025-01-10 00:52:53 [info ] starting inference run...
2025-01-10 00:52:57 [info ] compilation took 79.52986721400157 sec.
2025-01-10 00:52:57 [info ] compilation took 81.91776501700042 sec.
2025-01-10 00:52:57 [info ] compilation took 80.24951512600092 sec.
2025-01-10 00:52:57 [info ] starting inference run...
2025-01-10 00:52:57 [info ] starting inference run...
2025-01-10 00:52:58 [info ] starting inference run...
2025-01-10 00:53:22 [info ] inference time: 25.112665320000815
2025-01-10 00:53:30 [info ] inference time: 7.7019307739992655
2025-01-10 00:53:38 [info ] inference time: 7.693858365000779
2025-01-10 00:53:46 [info ] inference time: 7.690621814001133
2025-01-10 00:53:53 [info ] inference time: 7.679490454000188
2025-01-10 00:54:01 [info ] inference time: 7.68949568500102
2025-01-10 00:54:09 [info ] inference time: 7.686633744000574
2025-01-10 00:54:16 [info ] inference time: 7.696786873999372
2025-01-10 00:54:24 [info ] inference time: 7.691988694999964
2025-01-10 00:54:32 [info ] inference time: 7.700649563999832
2025-01-10 00:54:39 [info ] inference time: 7.684993574001055
2025-01-10 00:54:47 [info ] inference time: 7.68343457499941
2025-01-10 00:54:55 [info ] inference time: 7.667921153999487
2025-01-10 00:55:02 [info ] inference time: 7.683585194001353
2025-01-10 00:55:06 [info ] avg. inference over 15 iterations took 8.61202360273334 sec.
2025-01-10 00:55:07 [info ] avg. inference over 15 iterations took 8.952725123600006 sec.
2025-01-10 00:55:10 [info ] inference time: 7.673799695001435
2025-01-10 00:55:10 [info ] avg. inference over 15 iterations took 8.849190365400379 sec.
2025-01-10 00:55:10 [info ] saved metric information as /tmp/metrics_report.txt
2025-01-10 00:55:12 [info ] avg. inference over 15 iterations took 8.940161458400205 sec.
```
120 changes: 120 additions & 0 deletions examples/research_projects/pytorch_xla/inference/flux/flux_inference.py
@@ -0,0 +1,120 @@
from argparse import ArgumentParser
from pathlib import Path
from time import perf_counter

import structlog
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.debug.profiler as xp
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.runtime as xr

from diffusers import FluxPipeline


logger = structlog.get_logger()
metrics_filepath = "/tmp/metrics_report.txt"


def _main(index, args, text_pipe, ckpt_id):
    cache_path = Path("/tmp/data/compiler_cache_tRiLlium_eXp")
    cache_path.mkdir(parents=True, exist_ok=True)
    xr.initialize_cache(str(cache_path), readonly=False)

    profile_path = Path("/tmp/data/profiler_out_tRiLlium_eXp")
    profile_path.mkdir(parents=True, exist_ok=True)
    profiler_port = 9012
    profile_duration = args.profile_duration
    if args.profile:
        logger.info(f"starting profiler on port {profiler_port}")
        _ = xp.start_server(profiler_port)
    device0 = xm.xla_device()

    logger.info(f"loading flux from {ckpt_id}")
    flux_pipe = FluxPipeline.from_pretrained(
        ckpt_id, text_encoder=None, tokenizer=None, text_encoder_2=None, tokenizer_2=None, torch_dtype=torch.bfloat16
    ).to(device0)
    flux_pipe.transformer.enable_xla_flash_attention(partition_spec=("data", None, None, None), is_flux=True)

    prompt = "photograph of an electronics chip in the shape of a race car with trillium written on its side"
    width = args.width
    height = args.height
    guidance = args.guidance
    n_steps = 4 if args.schnell else 28

    logger.info("starting compilation run...")
    ts = perf_counter()
    with torch.no_grad():
        prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt(
            prompt=prompt, prompt_2=None, max_sequence_length=512
        )
        prompt_embeds = prompt_embeds.to(device0)
        pooled_prompt_embeds = pooled_prompt_embeds.to(device0)

        image = flux_pipe(
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            num_inference_steps=28,
            guidance_scale=guidance,
            height=height,
            width=width,
        ).images[0]
    logger.info(f"compilation took {perf_counter() - ts} sec.")
    image.save("/tmp/compile_out.png")

    base_seed = 4096 if args.seed is None else args.seed
    seed_range = 1000
    unique_seed = base_seed + index * seed_range
    xm.set_rng_state(seed=unique_seed, device=device0)
    times = []
    logger.info("starting inference run...")
    for _ in range(args.itters):
        ts = perf_counter()
        with torch.no_grad():
            prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt(
                prompt=prompt, prompt_2=None, max_sequence_length=512
            )
            prompt_embeds = prompt_embeds.to(device0)
            pooled_prompt_embeds = pooled_prompt_embeds.to(device0)

            if args.profile:
                xp.trace_detached(f"localhost:{profiler_port}", str(profile_path), duration_ms=profile_duration)
            image = flux_pipe(
                prompt_embeds=prompt_embeds,
                pooled_prompt_embeds=pooled_prompt_embeds,
                num_inference_steps=n_steps,
                guidance_scale=guidance,
                height=height,
                width=width,
            ).images[0]
            inference_time = perf_counter() - ts
            if index == 0:
                logger.info(f"inference time: {inference_time}")
            times.append(inference_time)
    logger.info(f"avg. inference over {args.itters} iterations took {sum(times)/len(times)} sec.")
    image.save(f"/tmp/inference_out-{index}.png")
    if index == 0:
        metrics_report = met.metrics_report()
        with open(metrics_filepath, "w+") as fout:
            fout.write(metrics_report)
        logger.info(f"saved metric information as {metrics_filepath}")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--schnell", action="store_true", help="run flux schnell instead of dev")
    parser.add_argument("--width", type=int, default=1024, help="width of the image to generate")
    parser.add_argument("--height", type=int, default=1024, help="height of the image to generate")
    parser.add_argument("--guidance", type=float, default=3.5, help="guidance strength for dev")
    parser.add_argument("--seed", type=int, default=None, help="seed for inference")
    parser.add_argument("--profile", action="store_true", help="enable profiling")
    parser.add_argument("--profile-duration", type=int, default=10000, help="duration for profiling in msec.")
    parser.add_argument("--itters", type=int, default=15, help="times to run inference and report the avg time in sec.")
    args = parser.parse_args()
    if args.schnell:
        ckpt_id = "black-forest-labs/FLUX.1-schnell"
    else:
        ckpt_id = "black-forest-labs/FLUX.1-dev"
    text_pipe = FluxPipeline.from_pretrained(ckpt_id, transformer=None, vae=None, torch_dtype=torch.bfloat16).to("cpu")
    xmp.spawn(_main, args=(args, text_pipe, ckpt_id))
113 changes: 108 additions & 5 deletions src/diffusers/models/attention_processor.py
@@ -297,7 +297,7 @@ def __init__(
        self.set_processor(processor)

    def set_use_xla_flash_attention(
        self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None
        self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None, **kwargs
    ) -> None:
        r"""
        Set whether to use xla flash attention from `torch_xla` or not.
@@ -316,7 +316,10 @@ def set_use_xla_flash_attention
            elif is_spmd() and is_torch_xla_version("<", "2.4"):
                raise "flash attention pallas kernel using SPMD is supported from torch_xla version 2.4"
            else:
                processor = XLAFlashAttnProcessor2_0(partition_spec)
                if len(kwargs) > 0 and kwargs.get("is_flux", None):
                    processor = XLAFluxFlashAttnProcessor2_0(partition_spec)
                else:
                    processor = XLAFlashAttnProcessor2_0(partition_spec)
        else:
            processor = (
                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
@@ -2318,9 +2321,8 @@ def __call__(
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

@@ -2522,6 +2524,7 @@ def __call__(
            key = apply_rotary_emb(key, image_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

@@ -3422,6 +3425,106 @@ def __call__(
        return hidden_states


class XLAFluxFlashAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
    """

    def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )
        if is_torch_xla_version("<", "2.3"):
            raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
        if is_spmd() and is_torch_xla_version("<", "2.4"):
            raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
        self.partition_spec = partition_spec

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        if encoder_hidden_states is not None:
            # `context` projections.
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # attention
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            from .embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        query /= math.sqrt(head_dim)
        hidden_states = flash_attention(query, key, value, causal=False)

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)

            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states
        else:
            return hidden_states


class MochiVaeAttnProcessor2_0:
r"""
Attention processor used in Mochi VAE.