Thank you for the outstanding work on FlexAttention! I am currently trying to integrate FlexAttention with the Hugging Face Transformers framework for training. However, I noticed that FlexAttention seems to consume more GPU memory than FlashAttention-2. The issue can be reproduced with the demo scripts in the Reproduction section below.
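For reference, the FlashAttention-2 baseline I compare against is the stock Transformers path. A minimal sketch of how such a baseline is typically selected (the checkpoint name is only a placeholder, not my exact setup):

import torch
from transformers import AutoModelForCausalLM

# Stock FlashAttention-2 baseline; any Llama-family checkpoint is selected the same way.
baseline = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder checkpoint
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)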
Reproduction
You need two files to reproduce my observations; both files live in the same folder. The first file defines the FlexAttention-based attention module and a patched LlamaModel forward:
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from transformers.models.llama.modeling_llama import (
    LlamaAttention, StaticCache, apply_rotary_pos_emb, repeat_kv, Cache, logger,
    DynamicCache, BaseModelOutputWithPast, FlashAttentionKwargs, Unpack, LlamaModel,
    add_start_docstrings_to_model_forward, LLAMA_INPUTS_DOCSTRING,
)
from typing import Optional, Tuple, Union, List
from functools import lru_cache


def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


def score_mod(score, b, h, q_idx, kv_idx):
    return score


flex_attention = torch.compile(flex_attention, mode="max-autotune")


@lru_cache
def create_block_mask_cached(mask_mod: Optional[torch.BoolTensor] = None, B: int = 1, H: int = 1, Q_LEN: int = 1, KV_LEN: int = 1, device: Optional[torch.device] = None):
    return torch.compile(
        create_block_mask(mask_mod=mask_mod, B=B, H=H, Q_LEN=Q_LEN, KV_LEN=KV_LEN, device=device, BLOCK_SIZE=(128, 64)),
        mode="max-autotune",
    )


class LlamaFlexAttention(LlamaAttention):
    """
    Llama flex attention module. This module inherits from `LlamaAttention` as the weights of the module stay
    untouched. The only required change is on the forward pass, where it needs to correctly call the public API of
    flex attention and deal with padding tokens in case the input contains any of them.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if isinstance(past_key_value, StaticCache):
            raise ValueError(
                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
            )

        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x head_dim x hidden_dim
        # therefore we just need to keep the original shape
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # TODO: These transpose are quite inefficient but Flash Attention requires the layout
        # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        # query_states = query_states.transpose(1, 2)
        # key_states = key_states.transpose(1, 2)
        # value_states = value_states.transpose(1, 2)
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        dropout_rate = self.attention_dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (LlamaRMSNorm handles it correctly)
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = flex_attention(
            query_states,
            key_states,
            value_states,
            block_mask=kwargs["block_mask"] if "block_mask" in kwargs else None,
            score_mod=None if "block_mask" in kwargs else score_mod,
        )

        attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def llama_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    if self.gradient_checkpointing and self.training and use_cache:
        logger.warning_once(
            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
        )
        use_cache = False

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    # kept for BC (non `Cache` `past_key_values` inputs)
    return_legacy_cache = False
    if use_cache and not isinstance(past_key_values, Cache):
        return_legacy_cache = True
        if past_key_values is None:
            past_key_values = DynamicCache()
        else:
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            logger.warning_once(
                "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
            )

    if cache_position is None:
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )
    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)

    causal_mask = self._update_causal_mask(
        attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
    )
    hidden_states = inputs_embeds

    # create position embeddings to be shared across the decoder layers
    position_embeddings = self.rotary_emb(hidden_states, position_ids)

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = None

    # block_mask
    if isinstance(self.layers[0].self_attn, LlamaFlexAttention):
        block_mask = create_block_mask_cached(
            mask_mod=causal_mask, B=1, H=1, Q_LEN=hidden_states.size(1), KV_LEN=hidden_states.size(1),
            device=hidden_states.device,
        )
        flash_attn_kwargs["block_mask"] = block_mask

    for decoder_layer in self.layers[: self.config.num_hidden_layers]:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                causal_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
                cache_position,
                position_embeddings,
                **flash_attn_kwargs,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **flash_attn_kwargs,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if return_legacy_cache:
        next_cache = next_cache.to_legacy_cache()

    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
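Usage

A minimal sketch of how the module above can be wired into a Llama model; the checkpoint name, input, and monkey-patching details are illustrative assumptions rather than my exact training script, and it assumes a transformers version that still exposes LLAMA_ATTENTION_CLASSES:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama import modeling_llama

# The module defined above, saved next to this script (file name assumed).
from flex_llama import LlamaFlexAttention, llama_model_forward

# Route Llama through the FlexAttention module and the patched model forward.
modeling_llama.LLAMA_ATTENTION_CLASSES["eager"] = LlamaFlexAttention
modeling_llama.LlamaModel.forward = llama_model_forward

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder checkpoint
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",  # so the patched attention class is instantiated
).cuda()

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
inputs = tokenizer("Hello world", return_tensors="pt").to("cuda")

out = model(**inputs, labels=inputs["input_ids"])
out.loss.backward()  # one training-style forward/backward pass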
Environments

My GPU devices are 8*A100-40G.

Observations

I have noticed that FlexAttention uses approximately 39 GB of GPU memory across the 8 devices, whereas FlashAttention-2 requires only around 36 GB. I am currently unsure whether this discrepancy arises from the internal implementation of FlexAttention or from the block mask; switching from the block mask to a score_mod did not resolve the issue either.
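One way to compare peak per-device memory between the two attention implementations is sketched below (an illustrative assumption about measurement, not necessarily how the numbers above were produced):

import torch

def report_peak_memory(tag: str) -> None:
    # Peak bytes on the current device since the last reset_peak_memory_stats() call.
    device = torch.cuda.current_device()
    allocated = torch.cuda.max_memory_allocated(device) / 1024**3
    reserved = torch.cuda.max_memory_reserved(device) / 1024**3
    print(f"[{tag}] device {device}: peak allocated {allocated:.2f} GiB, peak reserved {reserved:.2f} GiB")

# Call torch.cuda.reset_peak_memory_stats() before a training step, run the step,
# then report_peak_memory("flex_attention") or report_peak_memory("flash_attention_2").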
I would appreciate any insights or explanations regarding this matter! Thank you!