A Bug in flash_decoding_chunkllama.py #24

Open
ZayIsAllYouNeed opened this issue Jul 19, 2024 · 1 comment

ZayIsAllYouNeed commented Jul 19, 2024

flash_decoding_chunkllama.py calls _prepare_4d_causal_attention_mask_for_sdpa but never imports it from transformers.modeling_attn_mask_utils, so the SDPA code path fails with a NameError (a sketch of a possible fix follows the list).

flash_decoding_chunkllama.py:

  • Ln 510: attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
  • Ln 7: from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask (only the non-SDPA helper is imported)
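A minimal sketch of a fix, assuming the installed transformers version still exposes both private helpers in transformers.modeling_attn_mask_utils (they are internal utilities and may move between releases):

    # flash_decoding_chunkllama.py, Ln 7: import both mask helpers so the
    # SDPA branch at Ln 510 can resolve the name. Both are private
    # transformers utilities, so pin the transformers version accordingly.
    from transformers.modeling_attn_mask_utils import (
        _prepare_4d_causal_attention_mask,
        _prepare_4d_causal_attention_mask_for_sdpa,
    )
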
ChenxinAn-fdu (Contributor) commented Jul 21, 2024

Hi, thank you for pointing this out, but we usually use flash-attention in long-context settings. If you set attn_implementation="flash_attention_2", the error should not happen.

The relevant mask-selection code:

    if self._use_flash_attention_2:
        # 2d mask is passed through the layers
        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
    elif self._use_sdpa and not output_attentions:
        # output_attentions=True can not be supported when using SDPA, and we fall back on
        # the manual implementation that requires a 4D causal mask in all cases.
        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )
    else:
        # 4d mask is passed through the layers
        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

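For reference, a minimal sketch of loading a model with flash-attention enabled so the first branch above is taken (the checkpoint name and dtype below are placeholders, not taken from this thread):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Hypothetical checkpoint; substitute whichever Llama model you patch with ChunkLlama.
    model_name = "meta-llama/Llama-2-7b-hf"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        attn_implementation="flash_attention_2",  # takes the 2D-mask branch, avoiding the SDPA helper
        torch_dtype=torch.bfloat16,               # flash-attention requires fp16/bf16
        device_map="auto",
    )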