From 626e339b0228eb8ea0583fe9f49c5f61cfad63bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=E7=9F=B3=E9=A1=B5?= Date: Mon, 30 Dec 2024 11:28:51 +0800 Subject: [PATCH 01/18] NPU Adaption for Sanna --- examples/dreambooth/train_dreambooth_lora_sana.py | 11 +++++++---- src/diffusers/models/attention_processor.py | 11 ++++++++++- src/diffusers/models/transformers/sana_transformer.py | 11 +++++++++-- 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sana.py b/examples/dreambooth/train_dreambooth_lora_sana.py index 7bec9c799cae..795b2f7f73e0 100644 --- a/examples/dreambooth/train_dreambooth_lora_sana.py +++ b/examples/dreambooth/train_dreambooth_lora_sana.py @@ -63,6 +63,7 @@ is_wandb_available, ) from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.import_utils import is_torch_npu_available from diffusers.utils.torch_utils import is_compiled_module @@ -74,6 +75,9 @@ logger = get_logger(__name__) +if is_torch_npu_available(): + torch.npu.config.allow_internal_format = False + def save_model_card( repo_id: str, @@ -920,8 +924,7 @@ def main(args): image.save(image_filename) del pipeline - if torch.cuda.is_available(): - torch.cuda.empty_cache() + free_memory() # Handle the repository creation if accelerator.is_main_process: @@ -979,10 +982,10 @@ def main(args): ) # VAE should always be kept in fp32 for SANA (?) - vae.to(dtype=torch.float32) + vae.to(accelerator.device, dtype=torch.float32) transformer.to(accelerator.device, dtype=weight_dtype) # because Gemma2 is particularly suited for bfloat16. - text_encoder.to(dtype=torch.bfloat16) + text_encoder.to(accelerator.device, dtype=torch.bfloat16) # Initialize a text encoding pipeline and keep it to CPU for now. 
text_encoding_pipeline = SanaPipeline.from_pretrained( diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 4d7ae6bef26e..3ab5609ceeac 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -3147,7 +3147,16 @@ def __call__( attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + attn_mask = attention_mask[0] + seq_len = hidden_states.shape[1] + attention_mask = attn_mask.repeat_interleave(seq_len * batch_size, dim=0) + attention_mask = attention_mask.view(batch_size, 1, -1, attention_mask.shape[-1]) + + if attention_mask.dtype != torch.uint8: + if attention_mask.dtype == torch.bool: + attention_mask = torch.logical_not(attention_mask.bool()) + else: + attention_mask = attention_mask.to(torch.uint8) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index 027ab5fecefd..4d30162d0f56 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -19,11 +19,12 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers, is_torch_npu_available from ..attention_processor import ( Attention, AttentionProcessor, AttnProcessor2_0, + AttnProcessorNPU, SanaLinearAttnProcessor2_0, ) from ..embeddings import PatchEmbed, PixArtAlphaTextProjection @@ -119,6 +120,12 @@ def __init__( # 2. Cross Attention if cross_attention_dim is not None: self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + if is_torch_npu_available(): + attn_processor = AttnProcessorNPU() + else: + attn_processor = AttnProcessor2_0() + self.attn2 = Attention( query_dim=dim, cross_attention_dim=cross_attention_dim, @@ -127,7 +134,7 @@ def __init__( dropout=dropout, bias=True, out_bias=attention_out_bias, - processor=AttnProcessor2_0(), + processor=attn_processor, ) # 3. Feed-forward From 1a72a00765d148072117c5af984e3ad96c27e4ef Mon Sep 17 00:00:00 2001 From: Leo Jiang <74156916+leisuzz@users.noreply.github.com> Date: Mon, 30 Dec 2024 14:15:14 +0800 Subject: [PATCH 02/18] NPU Adaption for Sanna --- examples/dreambooth/train_dreambooth_lora_sana.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sana.py b/examples/dreambooth/train_dreambooth_lora_sana.py index 795b2f7f73e0..98509ebd3e44 100644 --- a/examples/dreambooth/train_dreambooth_lora_sana.py +++ b/examples/dreambooth/train_dreambooth_lora_sana.py @@ -982,10 +982,10 @@ def main(args): ) # VAE should always be kept in fp32 for SANA (?) - vae.to(accelerator.device, dtype=torch.float32) + vae.to(dtype=torch.float32) transformer.to(accelerator.device, dtype=weight_dtype) # because Gemma2 is particularly suited for bfloat16. 
- text_encoder.to(accelerator.device, dtype=torch.bfloat16) + text_encoder.to(dtype=torch.bfloat16) # Initialize a text encoding pipeline and keep it to CPU for now. text_encoding_pipeline = SanaPipeline.from_pretrained( From a1965dd697b83434043699a6b1447600737f626e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=E7=9F=B3=E9=A1=B5?= Date: Tue, 7 Jan 2025 10:16:52 +0800 Subject: [PATCH 03/18] NPU Adaption for Sanna --- examples/dreambooth/train_dreambooth_lora_sana.py | 10 ++++++++++ src/diffusers/models/transformers/sana_transformer.py | 11 ++--------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sana.py b/examples/dreambooth/train_dreambooth_lora_sana.py index 98509ebd3e44..ab194dad0ca7 100644 --- a/examples/dreambooth/train_dreambooth_lora_sana.py +++ b/examples/dreambooth/train_dreambooth_lora_sana.py @@ -601,6 +601,9 @@ def parse_args(input_args=None): help="Whether to offload the VAE and the text encoder to CPU when they are not used.", ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention." + ) if input_args is not None: args = parser.parse_args(input_args) @@ -967,6 +970,13 @@ def main(args): vae.requires_grad_(False) text_encoder.requires_grad_(False) + if args.enable_npu_flash_attention: + if is_torch_npu_available(): + logger.info("npu flash attention enabled.") + transformer.enable_npu_flash_attention() + else: + raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.") + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision # as these weights are only used for inference, keeping weights in full precision is not required. weight_dtype = torch.float32 diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index 4d30162d0f56..8c01849b2205 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -19,12 +19,11 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers, is_torch_npu_available +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ..attention_processor import ( Attention, AttentionProcessor, AttnProcessor2_0, - AttnProcessorNPU, SanaLinearAttnProcessor2_0, ) from ..embeddings import PatchEmbed, PixArtAlphaTextProjection @@ -120,12 +119,6 @@ def __init__( # 2. Cross Attention if cross_attention_dim is not None: self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) - - if is_torch_npu_available(): - attn_processor = AttnProcessorNPU() - else: - attn_processor = AttnProcessor2_0() - self.attn2 = Attention( query_dim=dim, cross_attention_dim=cross_attention_dim, @@ -134,7 +127,7 @@ def __init__( dropout=dropout, bias=True, out_bias=attention_out_bias, - processor=attn_processor, + processor=AttnProcessor2_0, ) # 3. 
Feed-forward From 326b98d8a94b8f56043bf24ea3cd6ac76b610023 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=E7=9F=B3=E9=A1=B5?= Date: Tue, 7 Jan 2025 10:18:40 +0800 Subject: [PATCH 04/18] NPU Adaption for Sanna --- src/diffusers/models/transformers/sana_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index 47eb81808825..bc3877627529 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -127,7 +127,7 @@ def __init__( dropout=dropout, bias=True, out_bias=attention_out_bias, - processor=AttnProcessor2_0, + processor=AttnProcessor2_0(), ) # 3. Feed-forward From 963e29068dff7768e2cd546ed58f2ed9368b7e22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=E7=9F=B3=E9=A1=B5?= Date: Tue, 7 Jan 2025 20:39:27 +0800 Subject: [PATCH 05/18] NPU Adaption for Sanna --- examples/dreambooth/train_dreambooth_lora_sana.py | 10 ---------- .../models/transformers/sana_transformer.py | 13 ++++++++++--- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sana.py b/examples/dreambooth/train_dreambooth_lora_sana.py index ab194dad0ca7..98509ebd3e44 100644 --- a/examples/dreambooth/train_dreambooth_lora_sana.py +++ b/examples/dreambooth/train_dreambooth_lora_sana.py @@ -601,9 +601,6 @@ def parse_args(input_args=None): help="Whether to offload the VAE and the text encoder to CPU when they are not used.", ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") - parser.add_argument( - "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention." - ) if input_args is not None: args = parser.parse_args(input_args) @@ -970,13 +967,6 @@ def main(args): vae.requires_grad_(False) text_encoder.requires_grad_(False) - if args.enable_npu_flash_attention: - if is_torch_npu_available(): - logger.info("npu flash attention enabled.") - transformer.enable_npu_flash_attention() - else: - raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.") - # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision # as these weights are only used for inference, keeping weights in full precision is not required. weight_dtype = torch.float32 diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index bc3877627529..4d30162d0f56 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -19,11 +19,12 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers, is_torch_npu_available from ..attention_processor import ( Attention, AttentionProcessor, AttnProcessor2_0, + AttnProcessorNPU, SanaLinearAttnProcessor2_0, ) from ..embeddings import PatchEmbed, PixArtAlphaTextProjection @@ -119,6 +120,12 @@ def __init__( # 2. 
Cross Attention if cross_attention_dim is not None: self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + if is_torch_npu_available(): + attn_processor = AttnProcessorNPU() + else: + attn_processor = AttnProcessor2_0() + self.attn2 = Attention( query_dim=dim, cross_attention_dim=cross_attention_dim, @@ -127,7 +134,7 @@ def __init__( dropout=dropout, bias=True, out_bias=attention_out_bias, - processor=AttnProcessor2_0(), + processor=attn_processor, ) # 3. Feed-forward @@ -250,6 +257,7 @@ def __init__( inner_dim = num_attention_heads * attention_head_dim # 1. Patch Embedding + interpolation_scale = interpolation_scale if interpolation_scale is not None else max(sample_size // 64, 1) self.patch_embed = PatchEmbed( height=sample_size, width=sample_size, @@ -257,7 +265,6 @@ def __init__( in_channels=in_channels, embed_dim=inner_dim, interpolation_scale=interpolation_scale, - pos_embed_type="sincos" if interpolation_scale is not None else None, ) # 2. Additional condition embeddings From 3d3aae3fb2d27d38c16d0d0eb55b63b73d05c5c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=E7=9F=B3=E9=A1=B5?= Date: Wed, 8 Jan 2025 08:59:59 +0800 Subject: [PATCH 06/18] NPU Adaption for Sanna --- src/diffusers/models/transformers/sana_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index 4d30162d0f56..ef3b43d99e78 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -257,7 +257,6 @@ def __init__( inner_dim = num_attention_heads * attention_head_dim # 1. Patch Embedding - interpolation_scale = interpolation_scale if interpolation_scale is not None else max(sample_size // 64, 1) self.patch_embed = PatchEmbed( height=sample_size, width=sample_size, @@ -265,6 +264,7 @@ def __init__( in_channels=in_channels, embed_dim=inner_dim, interpolation_scale=interpolation_scale, + pos_embed_type="sincos" if interpolation_scale is not None else None, ) # 2. Additional condition embeddings From 4cea81971347b2e89ad1c8cdc45ff7a3de5c1496 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=E7=9F=B3=E9=A1=B5?= Date: Wed, 8 Jan 2025 11:17:16 +0800 Subject: [PATCH 07/18] NPU Adaption for Sanna --- src/diffusers/models/attention_processor.py | 4 ++++ src/diffusers/models/transformers/sana_transformer.py | 8 +------- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 3ab5609ceeac..c621823e321e 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -294,6 +294,10 @@ def __init__( processor = ( AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() ) + + if is_torch_npu_available(): + if isinstance(processor, AttnProcessor2_0): + processor = AttnProcessorNPU() self.set_processor(processor) def set_use_xla_flash_attention( diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index ef3b43d99e78..0c91753edfdf 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -120,12 +120,6 @@ def __init__( # 2. 
Cross Attention if cross_attention_dim is not None: self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) - - if is_torch_npu_available(): - attn_processor = AttnProcessorNPU() - else: - attn_processor = AttnProcessor2_0() - self.attn2 = Attention( query_dim=dim, cross_attention_dim=cross_attention_dim, @@ -134,7 +128,7 @@ def __init__( dropout=dropout, bias=True, out_bias=attention_out_bias, - processor=attn_processor, + processor=AttnProcessor2_0(), ) # 3. Feed-forward From 0d9e1b3aa81535eb92549b1651ba3925a0028bd7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=E7=9F=B3=E9=A1=B5?= Date: Wed, 8 Jan 2025 11:22:37 +0800 Subject: [PATCH 08/18] NPU Adaption for Sanna --- src/diffusers/models/transformers/sana_transformer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index 0c91753edfdf..bc3877627529 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -19,12 +19,11 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers, is_torch_npu_available +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ..attention_processor import ( Attention, AttentionProcessor, AttnProcessor2_0, - AttnProcessorNPU, SanaLinearAttnProcessor2_0, ) from ..embeddings import PatchEmbed, PixArtAlphaTextProjection From 20520496ef7a21f7bb11cfd692f0ec620d1c326a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=E7=9F=B3=E9=A1=B5?= Date: Wed, 8 Jan 2025 11:47:42 +0800 Subject: [PATCH 09/18] NPU Adaption for Sanna --- src/diffusers/models/attention_processor.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index c621823e321e..b22c4c5af70f 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -294,10 +294,6 @@ def __init__( processor = ( AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() ) - - if is_torch_npu_available(): - if isinstance(processor, AttnProcessor2_0): - processor = AttnProcessorNPU() self.set_processor(processor) def set_use_xla_flash_attention( @@ -525,6 +521,11 @@ def set_processor(self, processor: "AttnProcessor") -> None: processor (`AttnProcessor`): The attention processor to use. 
""" + # set to use npu flash attention from 'torch_npu' if available + if is_torch_npu_available(): + if isinstance(processor, AttnProcessor2_0): + processor = AttnProcessorNPU() + # if current processor is in `self._modules` and if passed `processor` is not, we need to # pop `processor` from `self._modules` if ( From cfbbb8fd11ca9c046020f410d071588f6a6e3c32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=E7=9F=B3=E9=A1=B5?= Date: Tue, 14 Jan 2025 09:06:18 +0800 Subject: [PATCH 10/18] NPU Adaption for Sanna --- src/diffusers/models/attention_processor.py | 5 ----- .../models/transformers/sana_transformer.py | 12 ++++++++++-- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index b22c4c5af70f..3ab5609ceeac 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -521,11 +521,6 @@ def set_processor(self, processor: "AttnProcessor") -> None: processor (`AttnProcessor`): The attention processor to use. """ - # set to use npu flash attention from 'torch_npu' if available - if is_torch_npu_available(): - if isinstance(processor, AttnProcessor2_0): - processor = AttnProcessorNPU() - # if current processor is in `self._modules` and if passed `processor` is not, we need to # pop `processor` from `self._modules` if ( diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index bc3877627529..ae6b76a0edea 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -19,11 +19,12 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ..attention_processor import ( Attention, AttentionProcessor, AttnProcessor2_0, + AttnProcessorNPU, SanaLinearAttnProcessor2_0, ) from ..embeddings import PatchEmbed, PixArtAlphaTextProjection @@ -119,6 +120,13 @@ def __init__( # 2. Cross Attention if cross_attention_dim is not None: self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + # if NPU is available, will use NPU fused attention instead + if is_torch_npu_available(): + attn_processor = AttnProcessorNPU() + else: + attn_processor = AttnProcessor2_0() + self.attn2 = Attention( query_dim=dim, cross_attention_dim=cross_attention_dim, @@ -127,7 +135,7 @@ def __init__( dropout=dropout, bias=True, out_bias=attention_out_bias, - processor=AttnProcessor2_0(), + processor=attn_processor, ) # 3. 
Feed-forward From 4c1d56dc9c7c33e898418f48b64cedbf3f7f448a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=E7=9F=B3=E9=A1=B5?= Date: Thu, 23 Jan 2025 19:43:24 +0800 Subject: [PATCH 11/18] NPU Adaption for Sanna --- src/diffusers/models/attention_processor.py | 5 +++++ src/diffusers/models/transformers/sana_transformer.py | 9 +-------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 3ab5609ceeac..398d4d715b0a 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -521,6 +521,11 @@ def set_processor(self, processor: "AttnProcessor") -> None: processor (`AttnProcessor`): The attention processor to use. """ + # Set AttnProcessor to NPU if available + if is_torch_npu_available(): + if isinstance(processor, AttnProcessor2_0): + processor = AttnProcessorNPU() + # if current processor is in `self._modules` and if passed `processor` is not, we need to # pop `processor` from `self._modules` if ( diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index ae6b76a0edea..d3b686a2dfd8 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -120,13 +120,6 @@ def __init__( # 2. Cross Attention if cross_attention_dim is not None: self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) - - # if NPU is available, will use NPU fused attention instead - if is_torch_npu_available(): - attn_processor = AttnProcessorNPU() - else: - attn_processor = AttnProcessor2_0() - self.attn2 = Attention( query_dim=dim, cross_attention_dim=cross_attention_dim, @@ -135,7 +128,7 @@ def __init__( dropout=dropout, bias=True, out_bias=attention_out_bias, - processor=attn_processor, + processor=AttnProcessor2_0(), ) # 3. 
Feed-forward From d61d570f4402c038978c5c288dc37dcf503c0f41 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=E7=9F=B3=E9=A1=B5?= Date: Thu, 23 Jan 2025 19:44:50 +0800 Subject: [PATCH 12/18] NPU Adaption for Sanna --- src/diffusers/models/transformers/sana_transformer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index 54d0aa356a0f..3dac0d5dc7bf 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -19,12 +19,11 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ..attention_processor import ( Attention, AttentionProcessor, AttnProcessor2_0, - AttnProcessorNPU, SanaLinearAttnProcessor2_0, ) from ..embeddings import PatchEmbed, PixArtAlphaTextProjection From ab2d71b2a9e20f7811f1875749b17b90e533fd37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=E7=9F=B3=E9=A1=B5?= Date: Thu, 23 Jan 2025 19:51:42 +0800 Subject: [PATCH 13/18] NPU Adaption for Sanna --- src/diffusers/models/attention_processor.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index ef0b068ab6a9..f75b45ce7622 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -3158,10 +3158,12 @@ def __call__( attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) - attn_mask = attention_mask[0] - seq_len = hidden_states.shape[1] - attention_mask = attn_mask.repeat_interleave(seq_len * batch_size, dim=0) attention_mask = attention_mask.view(batch_size, 1, -1, attention_mask.shape[-1]) + attention_mask = attention_mask.repeat_interleave(hidden_states.shape[1], dim=2) + if attention_mask.dtype == torch.bool: + attention_mask = torch.logical_not(attention_mask.bool()) + else: + attention_mask = attention_mask.bool() if attention_mask.dtype != torch.uint8: if attention_mask.dtype == torch.bool: From a456fb1de1543f63124a23c4d0390cb1d6900c44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=E7=9F=B3=E9=A1=B5?= Date: Thu, 23 Jan 2025 20:00:56 +0800 Subject: [PATCH 14/18] NPU Adaption for Sanna --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index f75b45ce7622..6eb5b89287d4 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -3158,7 +3158,7 @@ def __call__( attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, 1, -1, attention_mask.shape[-1]) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) attention_mask = attention_mask.repeat_interleave(hidden_states.shape[1], dim=2) if attention_mask.dtype == 
torch.bool: attention_mask = torch.logical_not(attention_mask.bool()) From fedfdd47e70a343013de575efbf8cb0de8ac4b96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=E7=9F=B3=E9=A1=B5?= Date: Fri, 24 Jan 2025 08:51:05 +0800 Subject: [PATCH 15/18] NPU Adaption for Sanna --- src/diffusers/models/attention_processor.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 6eb5b89287d4..44b2153708d2 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -3165,12 +3165,6 @@ def __call__( else: attention_mask = attention_mask.bool() - if attention_mask.dtype != torch.uint8: - if attention_mask.dtype == torch.bool: - attention_mask = torch.logical_not(attention_mask.bool()) - else: - attention_mask = attention_mask.to(torch.uint8) - if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) From 3add6deb906ed45e0f20ef1287273b25b84c74ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=E7=9F=B3=E9=A1=B5?= Date: Fri, 24 Jan 2025 09:55:26 +0800 Subject: [PATCH 16/18] NPU Adaption for Sanna --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 44b2153708d2..9b74000b7a46 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -3159,7 +3159,7 @@ def __call__( # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - attention_mask = attention_mask.repeat_interleave(hidden_states.shape[1], dim=2) + attention_mask = attention_mask.repeat(1, 1, hidden_states.shape[1], 1) if attention_mask.dtype == torch.bool: attention_mask = torch.logical_not(attention_mask.bool()) else: From 70cf52954171eeba397cb509ac2174df664c4c03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=E7=9F=B3=E9=A1=B5?= Date: Fri, 24 Jan 2025 14:37:30 +0800 Subject: [PATCH 17/18] NPU Adaption for Sanna --- examples/dreambooth/train_dreambooth_lora_sana.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/examples/dreambooth/train_dreambooth_lora_sana.py b/examples/dreambooth/train_dreambooth_lora_sana.py index 1fc39751d5f8..9e69bd6a668b 100644 --- a/examples/dreambooth/train_dreambooth_lora_sana.py +++ b/examples/dreambooth/train_dreambooth_lora_sana.py @@ -605,6 +605,7 @@ def parse_args(input_args=None): ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") parser.add_argument("--enable_vae_tiling", action="store_true", help="Enabla vae tiling in log validation") + parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU") if input_args is not None: args = parser.parse_args(input_args) @@ -991,6 +992,13 @@ def main(args): # because Gemma2 is particularly suited for bfloat16. text_encoder.to(dtype=torch.bfloat16) + if args.enable_npu_flash_attention: + if is_torch_npu_available(): + logger.info("npu flash attention enabled.") + transformer.enable_npu_flash_attention() + else: + raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device ") + # Initialize a text encoding pipeline and keep it to CPU for now. 
text_encoding_pipeline = SanaPipeline.from_pretrained( args.pretrained_model_name_or_path, From 8f18aae3e62a6767039d84d4dd81a4a00a57196d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=E7=9F=B3=E9=A1=B5?= Date: Fri, 24 Jan 2025 14:38:09 +0800 Subject: [PATCH 18/18] NPU Adaption for Sanna --- src/diffusers/models/attention_processor.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 9b74000b7a46..26625753e4b6 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -527,11 +527,6 @@ def set_processor(self, processor: "AttnProcessor") -> None: processor (`AttnProcessor`): The attention processor to use. """ - # Set AttnProcessor to NPU if available - if is_torch_npu_available(): - if isinstance(processor, AttnProcessor2_0): - processor = AttnProcessorNPU() - # if current processor is in `self._modules` and if passed `processor` is not, we need to # pop `processor` from `self._modules` if (
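
Net effect of the series: patch 17 adds an `--enable_npu_flash_attention` flag to the DreamBooth LoRA script, and patch 18 backs out the automatic processor swap inside `set_processor`, so NPU flash attention ends up as an explicit opt-in through the model's `enable_npu_flash_attention()` method. A minimal sketch of that opt-in outside the training script, assuming `torch_npu` is installed and using a placeholder Sana checkpoint id:

    from diffusers import SanaTransformer2DModel
    from diffusers.utils.import_utils import is_torch_npu_available

    # Placeholder checkpoint id; substitute the Sana transformer you are actually training.
    transformer = SanaTransformer2DModel.from_pretrained(
        "Efficient-Large-Model/Sana_1600M_1024px_diffusers", subfolder="transformer"
    )

    if is_torch_npu_available():
        # Swaps AttnProcessor2_0 for AttnProcessorNPU so eligible attention blocks
        # run through torch_npu's fused attention kernel.
        transformer.enable_npu_flash_attention()
    else:
        raise ValueError("NPU flash attention requires torch_npu and an NPU device.")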