diff --git a/hf_transformers b/hf_transformers
index 241c04d368..6bc0fbcfa7 160000
--- a/hf_transformers
+++ b/hf_transformers
@@ -1 +1 @@
-Subproject commit 241c04d36867259cdf11dbb4e9d9a60f9cb65ebc
+Subproject commit 6bc0fbcfa7acb6ac4937e7456a76c2f7975fefec
diff --git a/setup.py b/setup.py
index d7a15ef921..739f5e9653 100644
--- a/setup.py
+++ b/setup.py
@@ -61,7 +61,7 @@
     "timeout-decorator",
    "torch",
     "torchvision",
-    "transformers~=4.47.1",
+    "transformers~=4.48.0",
 ]
diff --git a/src/adapters/models/gpt2/modeling_gpt2.py b/src/adapters/models/gpt2/modeling_gpt2.py
index bb6410f838..6555250c0b 100644
--- a/src/adapters/models/gpt2/modeling_gpt2.py
+++ b/src/adapters/models/gpt2/modeling_gpt2.py
@@ -15,12 +15,13 @@
 # limitations under the License.
 """PyTorch OpenAI GPT-2 model."""

-from typing import Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint

-from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2SdpaAttention
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, eager_attention_forward
 from transformers.utils import logging

 from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_
@@ -41,6 +42,7 @@ def forward(
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
+        **kwargs,
     ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
         if encoder_hidden_states is not None:
             if not hasattr(self, "q_attn"):
@@ -49,37 +51,72 @@ def forward(
                     "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                 )

-            query = self.q_attn(hidden_states)
-            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
+            query_states = self.q_attn(hidden_states)
+            key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
             attention_mask = encoder_attention_mask
         else:
-            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
+            query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)

-        query = self._split_heads(query, self.num_heads, self.head_dim)
-        key = self._split_heads(key, self.num_heads, self.head_dim)
-        value = self._split_heads(value, self.num_heads, self.head_dim)
+        shape_q = (*query_states.shape[:-1], -1, self.head_dim)
+        shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
+
+        query_states = query_states.view(shape_q).transpose(1, 2)
+        key_states = key_states.view(shape_kv).transpose(1, 2)
+        value_states = value_states.view(shape_kv).transpose(1, 2)

         if layer_past is not None:
             past_key, past_value = layer_past
-            key = torch.cat((past_key, key), dim=-2)
-            value = torch.cat((past_value, value), dim=-2)
+            key_states = torch.cat((past_key, key_states), dim=-2)
+            value_states = torch.cat((past_value, value_states), dim=-2)

         if use_cache is True:
-            present = (key, value)
+            present = (key_states, value_states)
         else:
             present = None

         # >>> START AH Changes <<<
-        key, value, attention_mask = self.prefix_tuning(key, value, hidden_states, attention_mask)
-        (query,) = adjust_tensors_for_parallel(key, query)
+        key_states, value_states, attention_mask = self.prefix_tuning(
+            key_states, value_states, hidden_states, attention_mask
+        )
+        (query_states,) = adjust_tensors_for_parallel(key_states, query_states)
         # >>> END AH Changes <<<

-        if self.reorder_and_upcast_attn:
-            attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
+        is_cross_attention = encoder_hidden_states is not None
+        is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
+
+        using_eager = self.config._attn_implementation == "eager"
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):
+                using_eager = True
+                logger.warning_once(
+                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+                )
+            else:
+                # Attention functions are consistent with previous equivalent attention classes, however they do not support some options
+                # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but
+                # not necessarily to eager (if mentioned options are provided).
+                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        if using_eager and self.reorder_and_upcast_attn:
+            attn_output, attn_weights = self._upcast_and_reordered_attn(
+                query_states, key_states, value_states, attention_mask, head_mask
+            )
         else:
-            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+            attn_output, attn_weights = attention_interface(
+                self,
+                query_states,
+                key_states,
+                value_states,
+                attention_mask,
+                head_mask=head_mask,
+                dropout=self.attn_dropout.p if self.training else 0.0,
+                is_causal=is_causal,
+                **kwargs,
+            )

-        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+        attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
         attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)

@@ -90,104 +127,6 @@ def forward(
         return outputs  # a, present, (attentions)


-class GPT2SdpaAttentionWithAdapters(GPT2AttentionAdaptersMixin, GPT2SdpaAttention):
-    def forward(
-        self,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        layer_past: Optional[Tuple[torch.Tensor]] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = False,
-        output_attentions: Optional[bool] = False,
-    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
-        if output_attentions or head_mask is not None:
-            logger.warning_once(
-                "`GPT2SdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
-                "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
-                "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
-                'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-            )
-            return super().forward(
-                hidden_states=hidden_states,
-                layer_past=layer_past,
-                attention_mask=attention_mask,
-                head_mask=head_mask,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-            )
-
-        bsz, q_len, _ = hidden_states.size()
-
-        # Initial attention projections
-        is_cross_attention = encoder_hidden_states is not None
-        if is_cross_attention:
-            if not hasattr(self, "q_attn"):
-                raise ValueError(
-                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
-                    "Please make sure to instantiate class with `GPT2SdpaAttention(..., is_cross_attention=True)`."
-                )
-
-            query = self.q_attn(hidden_states)
-            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
-            attention_mask = encoder_attention_mask
-        else:
-            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
-
-        query = self._split_heads(query, self.num_heads, self.head_dim)
-        key = self._split_heads(key, self.num_heads, self.head_dim)
-        value = self._split_heads(value, self.num_heads, self.head_dim)
-
-        # Optional kv caching
-        if layer_past is not None:
-            past_key = layer_past[0]
-            past_value = layer_past[1]
-            key = torch.cat((past_key, key), dim=-2)
-            value = torch.cat((past_value, value), dim=-2)
-
-        present = None
-        if use_cache is True:
-            present = (key, value)
-
-        # >>> START AH Changes <<<
-        key, value, attention_mask = self.prefix_tuning(key, value, hidden_states, attention_mask)
-        (query,) = adjust_tensors_for_parallel(key, query)
-        bsz = key.shape[0]
-        # >>> END AH Changes <<<
-
-        # Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA
-        if self.require_contiguous_qkv and query.device.type == "cuda" and attention_mask is not None:
-            query = query.contiguous()
-            key = key.contiguous()
-            value = value.contiguous()
-
-        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
-        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
-        is_causal = True if attention_mask is None and q_len > 1 and not is_cross_attention else False
-
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query,
-            key,
-            value,
-            attn_mask=attention_mask,
-            dropout_p=self.attn_dropout.p if self.training else 0.0,
-            is_causal=is_causal,
-        )
-
-        # Reshape outputs
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.view(bsz, q_len, self.embed_dim)
-
-        # Final projection
-        attn_output = self.c_proj(attn_output)
-        attn_output = self.resid_dropout(attn_output)
-
-        return attn_output, present, None
-
-
 class GPT2BlockWithAdapters(GPT2DecoderBlockAdaptersMixin, GPT2Block):
     def forward(
         self,
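
Note for reviewers unfamiliar with the dispatch mechanism this PR adopts: transformers 4.48 drops the per-backend attention subclasses (hence the deleted `GPT2SdpaAttentionWithAdapters`) in favor of a single `forward` that looks up an attention callable in `ALL_ATTENTION_FUNCTIONS`, keyed by `config._attn_implementation`, and falls back to `eager_attention_forward` whenever the selected backend cannot honor `output_attentions` or `head_mask`. Below is a minimal, self-contained sketch of that pattern; every name in it (`eager_attention`, `sdpa_attention`, `ATTENTION_FUNCTIONS`, `dispatch_attention`) is an illustrative stand-in, not the actual transformers API.

    from typing import Callable, Dict, Optional

    import torch
    import torch.nn.functional as F


    def eager_attention(query, key, value, attention_mask=None, dropout=0.0, **kwargs):
        # Plain matmul attention: materializes the weights, so it supports every
        # option and serves as the universal fallback. Causal masking is omitted
        # here for brevity; the mask argument stands in for it.
        scores = query @ key.transpose(-1, -2) / (query.size(-1) ** 0.5)
        if attention_mask is not None:
            scores = scores + attention_mask
        weights = F.softmax(scores, dim=-1)
        weights = F.dropout(weights, p=dropout)
        return weights @ value, weights


    def sdpa_attention(query, key, value, attention_mask=None, dropout=0.0, is_causal=False, **kwargs):
        # Fused kernel: faster, but never returns attention weights.
        out = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=dropout, is_causal=is_causal
        )
        return out, None


    # Registry of non-eager backends, analogous to ALL_ATTENTION_FUNCTIONS.
    ATTENTION_FUNCTIONS: Dict[str, Callable] = {"sdpa": sdpa_attention}


    def dispatch_attention(impl: str, output_attentions: bool, head_mask: Optional[torch.Tensor]) -> Callable:
        # Mirrors the fallback logic in the diff: SDPA cannot honor
        # `output_attentions` or `head_mask`, so drop back to eager in that case.
        if impl == "eager" or (impl == "sdpa" and (output_attentions or head_mask is not None)):
            return eager_attention
        return ATTENTION_FUNCTIONS[impl]


    q = k = v = torch.randn(1, 4, 8, 16)  # (batch, heads, seq_len, head_dim)
    attn_fn = dispatch_attention("sdpa", output_attentions=True, head_mask=None)
    out, weights = attn_fn(q, k, v)  # falls back to eager, so weights is not None

The upshot for this PR is that a single `GPT2AttentionWithAdapters.forward` now covers eager, SDPA, and flash-attention paths, which is why the SDPA-specific subclass above could be removed wholesale.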