
FlexAttention uses much more GPU memory than FlashAttention-2 #101

Open
ChenlongDeng opened this issue Jan 9, 2025 · 0 comments

ChenlongDeng commented Jan 9, 2025

Thank you for the outstanding work on FlexAttention! I am currently integrating FlexAttention into the Hugging Face Transformers framework for training. However, FlexAttention appears to consume more GPU memory than FlashAttention-2. The issue can be reproduced with the following demo scripts:

Reproduction

Reproducing this requires two files, placed in the same folder:

  1. memory_test.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, default_data_collator
import argparse
from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
from datasets import Dataset
from flex_attention import LlamaFlexAttention

parser = argparse.ArgumentParser()
parser.add_argument("--model_name_or_path", type=str, default="meta-llama/Llama-3.1-8B")
parser.add_argument("--attention_type", type=str, default="flex")
parser.add_argument("--train_length", type=int, default=2048)
parser.add_argument("--dataset_size", type=int, default=8192)
args = parser.parse_args()

if __name__ == "__main__":
    assert args.attention_type in ["flash_attention_2", "flex", "sdpa", "eager"], "Invalid attention type"

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    
    if args.attention_type == "flex":
        LLAMA_ATTENTION_CLASSES["flash_attention_2"] = LlamaFlexAttention
        attn_implementation = "flash_attention_2"
    else:
        attn_implementation = args.attention_type

    model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=torch.bfloat16, attn_implementation=attn_implementation)

    random_input_ids = torch.randint(low=0, high=tokenizer.vocab_size, size=(args.dataset_size, args.train_length))

    train_dataset = Dataset.from_dict({"input_ids": random_input_ids.tolist(), "labels": random_input_ids.tolist()})
    
    training_args = TrainingArguments(
        output_dir=f"./tmp-{args.attention_type}",
        overwrite_output_dir=True,
        num_train_epochs=1,
        per_device_train_batch_size=1,
        save_steps=500,
        save_total_limit=1,
        max_steps=500,
        logging_steps=10,
        logging_dir="./logs",
        logging_first_step=True,
        report_to="none",
        do_train=True,
        gradient_checkpointing=True,
        gradient_accumulation_steps=8,
        deepspeed="../../config/deepspeed/stage2-offload.json",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
    )

    trainer.train()
  2. flex_attention.py
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from transformers.models.llama.modeling_llama import LlamaAttention, StaticCache, apply_rotary_pos_emb, repeat_kv, Cache, logger, DynamicCache, BaseModelOutputWithPast, FlashAttentionKwargs, Unpack, LlamaModel, add_start_docstrings_to_model_forward, LLAMA_INPUTS_DOCSTRING
from typing import Optional, Tuple, Union, List
from functools import lru_cache

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def score_mod(score, b, h, q_idx, kv_idx):
    return score

flex_attention = torch.compile(flex_attention, mode="max-autotune")

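@lru_cache
def create_block_mask_cached(mask_mod=None, B: int = 1, H: int = 1, Q_LEN: int = 1, KV_LEN: int = 1, device: Optional[torch.device] = None):
    # mask_mod is a callable such as causal_mask above (not a tensor); BLOCK_SIZE is (Q block, KV block).
    # torch.compile only accepts callables, so call create_block_mask directly instead of compiling its BlockMask result;
    # lru_cache reuses the BlockMask for repeated calls with the same shape and device.
    return create_block_mask(mask_mod=mask_mod, B=B, H=H, Q_LEN=Q_LEN, KV_LEN=KV_LEN, device=device, BLOCK_SIZE=(128, 64))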

class LlamaFlexAttention(LlamaAttention):
    """
    Llama flex attention module. This module inherits from `LlamaAttention`, as the weights of the module stay
    untouched. The only required change is in the forward pass, which calls the public flex_attention API.
    Note that `attention_mask` is not used, so padding tokens are not handled.
    """
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if isinstance(past_key_value, StaticCache):
            raise ValueError(
                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
            )

        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # flex_attention expects (batch_size, num_heads, seq_len, head_dim),
        # so split the hidden dimension into heads and move the head axis before the sequence axis
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # flex_attention consumes the (batch_size, num_heads, seq_len, head_dim) layout directly,
        # so no transpose back to [batch_size, seq_len, num_heads, head_dim] is needed here.
        # repeat_kv expands the grouped K/V heads to match the number of query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # flex_attention takes no dropout argument, so dropout_rate is not applied in this module.
        dropout_rate = self.attention_dropout if self.training else 0.0

        # In PEFT, the layer norms are usually cast to float32 for training stability,
        # so the input hidden states may be silently upcast to float32. Cast them back
        # to the expected dtype to be safe. This can slow down training and inference,
        # so it is recommended not to cast the LayerNorms to fp32 (LlamaRMSNorm handles this correctly).

        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = flex_attention(
            query_states,
            key_states,
            value_states,
            block_mask=kwargs["block_mask"] if "block_mask" in kwargs else None,
            score_mod=None if "block_mask" in kwargs else score_mod,
        )

        attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def llama_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    if self.gradient_checkpointing and self.training and use_cache:
        logger.warning_once(
            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
        )
        use_cache = False

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    # kept for BC (non `Cache` `past_key_values` inputs)
    return_legacy_cache = False
    if use_cache and not isinstance(past_key_values, Cache):
        return_legacy_cache = True
        if past_key_values is None:
            past_key_values = DynamicCache()
        else:
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            logger.warning_once(
                "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
            )

    if cache_position is None:
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )
    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)

    causal_mask = self._update_causal_mask(
        attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
    )
    hidden_states = inputs_embeds

    # create position embeddings to be shared across the decoder layers
    position_embeddings = self.rotary_emb(hidden_states, position_ids)

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = None
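    # block_mask: the local `causal_mask` above holds the HF attention mask and shadows the
    # module-level causal_mask mask_mod, so look the function up in the module globals.
    if isinstance(self.layers[0].self_attn, LlamaFlexAttention):
        block_mask = create_block_mask_cached(mask_mod=globals()["causal_mask"], B=1, H=1, Q_LEN=hidden_states.size(1), KV_LEN=hidden_states.size(1), device=hidden_states.device)
        flash_attn_kwargs["block_mask"] = block_mask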

    for decoder_layer in self.layers[: self.config.num_hidden_layers]:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                causal_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
                cache_position,
                position_embeddings,
                **flash_attn_kwargs,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **flash_attn_kwargs,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if return_legacy_cache:
        next_cache = next_cache.to_legacy_cache()

    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
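To make the two attention paths in flex_attention.py concrete, here is a small pure-Python reference for the score_mod/mask_mod semantics. It is an illustrative sketch of the documented contract, not the fused FlexAttention kernel: `causal_mask` used as a mask_mod drops every score where `q_idx < kv_idx`, whereas the identity `score_mod` (used when no `block_mask` is passed) leaves all scores in place, so that path also attends to future tokens. `reference_attention` and `causal_score_mod` are hypothetical helpers written for this example.

```python
import math


def causal_mask(b, h, q_idx, kv_idx):
    # Same predicate as in flex_attention.py: True means "keep this score".
    return q_idx >= kv_idx


def identity_score_mod(score, b, h, q_idx, kv_idx):
    # Same as score_mod in flex_attention.py: every score is left untouched.
    return score


def causal_score_mod(score, b, h, q_idx, kv_idx):
    # Hypothetical causal score_mod: hides future positions with -inf.
    return score if q_idx >= kv_idx else -math.inf


def reference_attention(q, k, v, score_mod=None, mask_mod=None):
    """Naive single-batch, single-head attention over lists of vectors.

    Each scaled dot product goes through score_mod, then any score whose
    mask_mod is False is replaced by -inf before the softmax.
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi, qv in enumerate(q):
        scores = []
        for ki, kv in enumerate(k):
            s = scale * sum(a * b for a, b in zip(qv, kv))
            if score_mod is not None:
                s = score_mod(s, 0, 0, qi, ki)
            if mask_mod is not None and not mask_mod(0, 0, qi, ki):
                s = -math.inf
            scores.append(s)
        m = max(scores)  # finite as long as each row keeps at least one key
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * vv[d] for w, vv in zip(weights, v)) / total for d in range(len(v[0]))])
    return out


q = k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]

with_mask = reference_attention(q, k, v, mask_mod=causal_mask)
with_causal_score = reference_attention(q, k, v, score_mod=causal_score_mod)
with_identity = reference_attention(q, k, v, score_mod=identity_score_mod)
print(with_mask[0], with_identity[0])
```

The causal score_mod and the causal mask_mod produce identical outputs, while the identity score_mod differs on every row except the last. As far as I understand FlexAttention, the practical difference between the first two is that only a BlockMask lets the kernel skip fully masked blocks; a score_mod-based mask still visits them.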

Usage

torchrun --nproc_per_node=8 memory_test.py --attention_type flex # FlexAttention
torchrun --nproc_per_node=8 memory_test.py --attention_type flash_attention_2 # FlashAttention-2
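For reference, the batch sizes implied by these commands and the TrainingArguments in memory_test.py can be worked out directly. This is plain arithmetic on the values in the scripts; the only library fact used is that `max_steps` overrides `num_train_epochs`. Only the per-device micro-batch (1 × 2048 tokens) sets the activation footprint on each GPU, and it is the same for both attention types; world size and gradient accumulation change how many sequences feed each optimizer step.

```python
# Values copied from memory_test.py and the torchrun command line.
NPROC_PER_NODE = 8
PER_DEVICE_BATCH = 1
GRAD_ACCUM_STEPS = 8
TRAIN_LENGTH = 2048
DATASET_SIZE = 8192
MAX_STEPS = 500

# Sequences and tokens consumed by one optimizer step across all GPUs.
sequences_per_step = NPROC_PER_NODE * PER_DEVICE_BATCH * GRAD_ACCUM_STEPS
tokens_per_step = sequences_per_step * TRAIN_LENGTH

# Tokens in a single forward/backward pass on one GPU (drives activation memory).
tokens_per_micro_batch = PER_DEVICE_BATCH * TRAIN_LENGTH

# max_steps overrides num_train_epochs, so the run covers this many passes over the data.
epochs_covered = MAX_STEPS * sequences_per_step / DATASET_SIZE

print(sequences_per_step, tokens_per_step, tokens_per_micro_batch, epochs_covered)
```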

Environment

torch==2.6.0.dev20241218+cu118
transformers==4.47.1

My GPUs are 8 × A100 (40 GB).

Observations

I have noticed that FlexAttention uses approximately 39 GB of GPU memory per device (8 devices in total), whereas FlashAttention-2 requires only around 36 GB. I am not sure whether this discrepancy comes from the internal implementation of FlexAttention or from the block mask; replacing the block mask with a score_mod did not resolve the issue either.
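One difference between the two code paths that can be quantified directly: `LlamaFlexAttention` calls `repeat_kv`, which materializes K and V at the full query-head count, whereas the `flash_attention_2` path in transformers 4.47 passes the grouped K/V straight to the kernel. The back-of-the-envelope sketch below assumes Llama-3.1-8B shapes (32 query heads, 8 KV heads, head_dim 128, 32 layers) and bf16 storage; `kv_bytes` and `repeat_kv_overhead` are hypothetical helpers for this estimate only.

```python
# Assumed Llama-3.1-8B attention shapes; bf16 uses 2 bytes per element.
NUM_HEADS = 32
NUM_KV_HEADS = 8
HEAD_DIM = 128
NUM_LAYERS = 32
BYTES_PER_ELEM = 2


def kv_bytes(batch: int, seq_len: int, heads: int) -> int:
    """Bytes for one K (or V) tensor of shape (batch, heads, seq_len, head_dim)."""
    return batch * heads * seq_len * HEAD_DIM * BYTES_PER_ELEM


def repeat_kv_overhead(batch: int, seq_len: int) -> int:
    """Extra bytes per layer when K and V are expanded from NUM_KV_HEADS to NUM_HEADS."""
    grouped = kv_bytes(batch, seq_len, NUM_KV_HEADS)
    expanded = kv_bytes(batch, seq_len, NUM_HEADS)
    return 2 * (expanded - grouped)  # factor 2: both K and V are expanded


extra = repeat_kv_overhead(batch=1, seq_len=2048)
print(f"extra per layer: {extra / 2**20:.0f} MiB")
print(f"upper bound, all layers alive: {NUM_LAYERS * extra / 2**20:.0f} MiB")
```

Under these assumptions the overhead is about 24 MiB per layer, and even if every layer kept its expanded copy it would be roughly 768 MiB, so `repeat_kv` alone seems unlikely to explain a ~3 GB gap, especially with gradient checkpointing enabled. `flex_attention` also accepts `enable_gqa=True` (PyTorch 2.5 and later), which avoids the expanded copies; whether it changes the measured peak here is untested.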

I would appreciate any insights or explanations regarding this matter! Thank you!
