Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update modular_modernbert -- add inputs_embeds param to ModernBertModel #35373

Merged
merged 5 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 68 additions & 26 deletions src/transformers/models/modernbert/modeling_modernbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,12 +205,17 @@ def __init__(self, config: ModernBertConfig):
def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
return self.drop(self.norm(self.tok_embeddings(input_ids)))

def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
hidden_states = (
self.compiled_embeddings(input_ids)
if self.config.reference_compile
else self.drop(self.norm(self.tok_embeddings(input_ids)))
)
def forward(
self, input_ids: torch.LongTensor = None, inputs_embeds: Optional[torch.Tensor] = None
) -> torch.Tensor:
if inputs_embeds is not None:
hidden_states = self.drop(self.norm(inputs_embeds))
else:
hidden_states = (
self.compiled_embeddings(input_ids)
if self.config.reference_compile
else self.drop(self.norm(self.tok_embeddings(input_ids)))
)
return hidden_states


Expand Down Expand Up @@ -794,6 +799,10 @@ def _pad_modernbert_output(
config.n_positions - 1]`.

[What are position IDs?](../glossary#position-ids)
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
Expand Down Expand Up @@ -845,10 +854,11 @@ def set_input_embeddings(self, value):
)
def forward(
self,
input_ids: torch.LongTensor = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
sliding_window_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
indices: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
Expand All @@ -864,35 +874,49 @@ def forward(
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None

self._maybe_set_compile()
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)

if input_ids is not None:
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)

if batch_size is None and seq_len is None:
batch_size, seq_len = input_ids.shape[:2]
if inputs_embeds is not None:
batch_size, seq_len = inputs_embeds.shape[:2]
else:
batch_size, seq_len = input_ids.shape[:2]
device = input_ids.device if input_ids is not None else inputs_embeds.device

if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool)
attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)

repad = False
if self.config._attn_implementation == "flash_attention_2":
if indices is None and cu_seqlens is None and max_seqlen is None:
repad = True
with torch.no_grad():
input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
inputs=input_ids, attention_mask=attention_mask
if inputs_embeds is None:
with torch.no_grad():
input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
inputs=input_ids, attention_mask=attention_mask
)
else:
inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
inputs=inputs_embeds, attention_mask=attention_mask
)
else:
if position_ids is None:
position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
position_ids = torch.arange(seq_len, device=device).unsqueeze(0)

attention_mask, sliding_window_mask = self._update_attention_mask(
attention_mask, output_attentions=output_attentions
)

hidden_states = self.embeddings(input_ids)
hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)

for encoder_layer in self.layers:
if output_hidden_states:
Expand Down Expand Up @@ -1028,10 +1052,11 @@ def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
)
def forward(
self,
input_ids: Optional[torch.Tensor],
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
sliding_window_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
indices: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
Expand All @@ -1048,19 +1073,32 @@ def forward(

if self.config._attn_implementation == "flash_attention_2":
if indices is None and cu_seqlens is None and max_seqlen is None:
batch_size, seq_len = input_ids.shape[:2]
if batch_size is None and seq_len is None:
if inputs_embeds is not None:
batch_size, seq_len = inputs_embeds.shape[:2]
else:
batch_size, seq_len = input_ids.shape[:2]
device = input_ids.device if input_ids is not None else inputs_embeds.device

if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool)
with torch.no_grad():
input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)

if inputs_embeds is None:
with torch.no_grad():
input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
)
else:
inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
)

outputs = self.model(
input_ids,
input_ids=input_ids,
attention_mask=attention_mask,
sliding_window_mask=sliding_window_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
Expand Down Expand Up @@ -1133,10 +1171,11 @@ def __init__(self, config: ModernBertConfig):
)
def forward(
self,
input_ids: Optional[torch.Tensor],
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
sliding_window_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
indices: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
Expand All @@ -1158,10 +1197,11 @@ def forward(
self._maybe_set_compile()

outputs = self.model(
input_ids,
input_ids=input_ids,
attention_mask=attention_mask,
sliding_window_mask=sliding_window_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
Expand Down Expand Up @@ -1244,10 +1284,11 @@ def __init__(self, config: ModernBertConfig):
)
def forward(
self,
input_ids: Optional[torch.Tensor],
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
sliding_window_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
indices: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
Expand All @@ -1266,10 +1307,11 @@ def forward(
self._maybe_set_compile()

outputs = self.model(
input_ids,
input_ids=input_ids,
attention_mask=attention_mask,
sliding_window_mask=sliding_window_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
Expand Down
Loading