-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
NPU Adaption for Sanna #10409
NPU Adaption for Sanna #10409
Changes from 17 commits
626e339
1a72a00
4d67c54
a1965dd
2c3b117
326b98d
715822f
963e290
510e1d6
3d3aae3
4cea819
0d9e1b3
2052049
487dd1a
cfbbb8f
7b8ad74
d7d54d8
ad4beaa
52d8c71
4c1d56d
63e3459
d61d570
ab2d71b
a323229
a456fb1
fedfdd4
7364276
3add6de
70cf529
8f18aae
feb8064
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,11 +19,12 @@ | |
|
||
from ...configuration_utils import ConfigMixin, register_to_config | ||
from ...loaders import PeftAdapterMixin | ||
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers | ||
from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, is_torch_version, logging, scale_lora_layers, unscale_lora_layers | ||
from ..attention_processor import ( | ||
Attention, | ||
AttentionProcessor, | ||
AttnProcessor2_0, | ||
AttnProcessorNPU, | ||
SanaLinearAttnProcessor2_0, | ||
) | ||
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection | ||
|
@@ -119,6 +120,13 @@ def __init__( | |
# 2. Cross Attention | ||
if cross_attention_dim is not None: | ||
self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) | ||
|
||
# if NPU is available, will use NPU fused attention instead | ||
if is_torch_npu_available(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same comment as in the other PR - let's not update default attn processor logic for now There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I've updated the new one, please take a look. This can just use set up NPU FA directly There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will let you know when the full test is complete There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @yiyixuxu It still needs to modify the sanna_transformer file, so I think to check in the init it;s the best option now |
||
attn_processor = AttnProcessorNPU() | ||
else: | ||
attn_processor = AttnProcessor2_0() | ||
|
||
self.attn2 = Attention( | ||
query_dim=dim, | ||
cross_attention_dim=cross_attention_dim, | ||
|
@@ -127,7 +135,7 @@ def __init__( | |
dropout=dropout, | ||
bias=True, | ||
out_bias=attention_out_bias, | ||
processor=AttnProcessor2_0(), | ||
processor=attn_processor, | ||
) | ||
|
||
# 3. Feed-forward | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@lawrence-cj let me know if it's ok with you to default to NPU attention when it's available:)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, I'm not familiar with NPU training and inference. Is this NPU device very popular in diffusers community?