[WIP] Upgrade Transformers to 4.48.x #782

Draft · wants to merge 2 commits into base: main
2 changes: 1 addition & 1 deletion hf_transformers
Submodule hf_transformers updated 1916 files
2 changes: 1 addition & 1 deletion setup.py
@@ -61,7 +61,7 @@
"timeout-decorator",
"torch",
"torchvision",
"transformers~=4.47.1",
"transformers~=4.48.0",
]
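The only functional change in `setup.py` is bumping the compatible-release pin from 4.47.1 to 4.48.0. As a hedged aside (not part of the diff), a minimal sketch of what the `~=` specifier accepts, assuming the `packaging` library is available:

```python
# Minimal sketch: "~=4.48.0" is a compatible-release pin, equivalent to
# ">=4.48.0, <4.49.0", i.e. any 4.48.x patch release of transformers.
from packaging.specifiers import SpecifierSet

spec = SpecifierSet("~=4.48.0")
print("4.48.2" in spec)  # True  -- patch releases satisfy the pin
print("4.49.0" in spec)  # False -- the next minor release does not
```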


169 changes: 54 additions & 115 deletions src/adapters/models/gpt2/modeling_gpt2.py
@@ -15,12 +15,13 @@
# limitations under the License.
"""PyTorch OpenAI GPT-2 model."""

from typing import Optional, Tuple, Union
from typing import Callable, Optional, Tuple, Union

import torch
import torch.utils.checkpoint

from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2SdpaAttention
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, eager_attention_forward
from transformers.utils import logging

from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_
@@ -41,6 +42,7 @@ def forward(
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
**kwargs,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn"):
@@ -49,37 +51,72 @@
"Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
)

query = self.q_attn(hidden_states)
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
query_states = self.q_attn(hidden_states)
key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
attention_mask = encoder_attention_mask
else:
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)

query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)
shape_q = (*query_states.shape[:-1], -1, self.head_dim)
shape_kv = (*key_states.shape[:-1], -1, self.head_dim)

query_states = query_states.view(shape_q).transpose(1, 2)
key_states = key_states.view(shape_kv).transpose(1, 2)
value_states = value_states.view(shape_kv).transpose(1, 2)

if layer_past is not None:
past_key, past_value = layer_past
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
key_states = torch.cat((past_key, key_states), dim=-2)
value_states = torch.cat((past_value, value_states), dim=-2)

if use_cache is True:
present = (key, value)
present = (key_states, value_states)
else:
present = None

# >>> START AH Changes <<<
key, value, attention_mask = self.prefix_tuning(key, value, hidden_states, attention_mask)
(query,) = adjust_tensors_for_parallel(key, query)
key_states, value_states, attention_mask = self.prefix_tuning(
key_states, value_states, hidden_states, attention_mask
)
(query_states,) = adjust_tensors_for_parallel(key_states, query_states)
# >>> END AH Changes <<<

if self.reorder_and_upcast_attn:
attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
is_cross_attention = encoder_hidden_states is not None
is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention

using_eager = self.config._attn_implementation == "eager"
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):
using_eager = True
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
# Attention functions are consistent with previous equivalent attention classes, however they do not support some options
# (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but
# not necessarily to eager (if the mentioned options are provided).
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

if using_eager and self.reorder_and_upcast_attn:
attn_output, attn_weights = self._upcast_and_reordered_attn(
query_states, key_states, value_states, attention_mask, head_mask
)
else:
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
head_mask=head_mask,
dropout=self.attn_dropout.p if self.training else 0.0,
is_causal=is_causal,
**kwargs,
)

attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)

@@ -90,104 +127,6 @@ def forward(
return outputs # a, present, (attentions)
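The rewritten `forward` above dispatches attention through the backend registry used in Transformers 4.48: implementations are looked up by name from `ALL_ATTENTION_FUNCTIONS`, with `eager_attention_forward` as the fallback when SDPA cannot honor `output_attentions` or `head_mask`. A minimal sketch of that selection logic, assuming the 4.48 import paths shown in the diff (the helper name `resolve_attention` is illustrative, not part of the PR):

```python
# Sketch of the backend selection mirrored from the diff above, assuming
# transformers~=4.48. Registered backends share the signature
# (module, query, key, value, attention_mask, **kwargs) -> (attn_output, attn_weights).
from typing import Callable, Optional

import torch
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.gpt2.modeling_gpt2 import eager_attention_forward


def resolve_attention(config, output_attentions: bool, head_mask: Optional[torch.Tensor]) -> Callable:
    """Illustrative helper: SDPA cannot return attention weights or apply a
    head mask, so those cases fall back to the eager implementation."""
    impl = config._attn_implementation
    if impl == "eager":
        return eager_attention_forward
    if impl == "sdpa" and (output_attentions or head_mask is not None):
        return eager_attention_forward
    return ALL_ATTENTION_FUNCTIONS[impl]
```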


class GPT2SdpaAttentionWithAdapters(GPT2AttentionAdaptersMixin, GPT2SdpaAttention):
def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
if output_attentions or head_mask is not None:
logger.warning_once(
"`GPT2SdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
"`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)

bsz, q_len, _ = hidden_states.size()

# Initial attention projections
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention:
if not hasattr(self, "q_attn"):
raise ValueError(
"If class is used as cross attention, the weights `q_attn` have to be defined. "
"Please make sure to instantiate class with `GPT2SdpaAttention(..., is_cross_attention=True)`."
)

query = self.q_attn(hidden_states)
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
attention_mask = encoder_attention_mask
else:
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)

# Optional kv caching
if layer_past is not None:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)

present = None
if use_cache is True:
present = (key, value)

# >>> START AH Changes <<<
key, value, attention_mask = self.prefix_tuning(key, value, hidden_states, attention_mask)
(query,) = adjust_tensors_for_parallel(key, query)
bsz = key.shape[0]
# >>> END AH Changes <<<

# Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA
if self.require_contiguous_qkv and query.device.type == "cuda" and attention_mask is not None:
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()

# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if attention_mask is None and q_len > 1 and not is_cross_attention else False

attn_output = torch.nn.functional.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=self.attn_dropout.p if self.training else 0.0,
is_causal=is_causal,
)

# Reshape outputs
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.embed_dim)

# Final projection
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)

return attn_output, present, None
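The dedicated `GPT2SdpaAttentionWithAdapters` class above is dropped in line with the upstream 4.48 attention refactor: the SDPA path is now selected through `config._attn_implementation` and handled by the shared attention interface in `GPT2AttentionWithAdapters.forward`. A hedged usage sketch of how the backend is chosen at load time ("gpt2" is only an illustrative checkpoint name):

```python
# Usage sketch, assuming transformers~=4.48: the attention backend is chosen
# via `attn_implementation` when loading, rather than by a dedicated SDPA
# attention subclass.
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2", attn_implementation="sdpa")
print(model.config._attn_implementation)  # "sdpa"
```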


class GPT2BlockWithAdapters(GPT2DecoderBlockAdaptersMixin, GPT2Block):
def forward(
self,