
update modular_modernbert -- add inputs_embeds param to ModernBertModel #35373

Merged (5 commits, Jan 9, 2025)

Conversation

jxmorris12 (Contributor)

What does this PR do?

Hi! Congrats on the release of ModernBERT; it looks amazing. I'm interested in using ModernBERT eventually to train a new Contextual Document Embeddings model.

One desired feature is to pass the contextual and word embeddings together in the second stage, which requires setting the inputs_embeds kwarg so that we can pass hidden states directly. This is a feature of typical BERT and other transformer implementations but isn't yet allowed by ModernBERT, so I added it. It's only a few additional lines of code.
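For context, a minimal sketch of the intended usage once inputs_embeds is supported (the checkpoint name is the released ModernBERT base model; building the embeddings via get_input_embeddings() is just one illustrative way to obtain them):

    from transformers import AutoModel, AutoTokenizer

    model = AutoModel.from_pretrained("answerdotai/ModernBERT-base")
    tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")

    inputs = tokenizer("hello world", return_tensors="pt")
    # Build embeddings outside the model, e.g. to mix in contextual vectors.
    embeds = model.get_input_embeddings()(inputs["input_ids"])

    # Pass precomputed embeddings instead of token ids.
    outputs = model(inputs_embeds=embeds, attention_mask=inputs["attention_mask"])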

cc: @warner-benjamin @tomaarsen @orionw @staghado @bclavie @NohTow @ArthurZucker

@ArthurZucker (Collaborator) left a comment

Thanks! Yep, would be nice to add. Let's make it a bit simpler (see how Llama has fewer if/elses) and add a small test! 🤗

@tomaarsen self-requested a review (December 28, 2024, 12:18)
@tomaarsen (Member)

Hello @jxmorris12, @ArthurZucker,

I pushed some changes into this PR to get it closer to completion. Let me know if you're not okay with this, and you can easily revert or delete the commits.

The changes:

  1. The inputs_embeds shouldn't fully replace self.embeddings(input_ids), because that call also does layer normalization and dropout. So now both input_ids and inputs_embeds are passed to ModernBertEmbeddings, much like how BertEmbeddings is implemented.
  2. I added inputs_embeds to the docstring and propagated the changes to the other model classes.
  3. I introduced an error if input_ids and inputs_embeds are both provided, or neither.
  4. I fixed an issue with the device being based solely on input_ids when working with attention_mask.
  5. I fixed a test and reintroduced another.

Let's make it a bit simpler (see how Llama has fewer if/elses)

This is sadly not as simple as it seems due to _unpad_modernbert_input.
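For reference, a toy sketch of the pattern described in points (1) and (3) above (the attribute names mirror the ModernBERT modeling code, but this is illustrative, not the actual implementation):

    from torch import nn

    class EmbeddingsSketch(nn.Module):
        # Toy stand-in for ModernBertEmbeddings: layer norm and dropout are applied
        # whether token ids or precomputed embeddings are provided.
        def __init__(self, vocab_size=30, dim=8):
            super().__init__()
            self.tok_embeddings = nn.Embedding(vocab_size, dim)
            self.norm = nn.LayerNorm(dim)
            self.drop = nn.Dropout(0.1)

        def forward(self, input_ids=None, inputs_embeds=None):
            # Exactly one of input_ids / inputs_embeds must be given (XOR check).
            if (input_ids is None) ^ (inputs_embeds is not None):
                raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
            hidden = inputs_embeds if inputs_embeds is not None else self.tok_embeddings(input_ids)
            return self.drop(self.norm(hidden))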

  • Tom Aarsen

Comment on lines 1227 to 1228
if self.config._attn_implementation == "flash_attention_2":
    if indices is None and cu_seqlens is None and max_seqlen is None:
Member:

It's a bit frustrating that this entire tree is necessary, but there's no other convenient way to prevent the base model from repadding while still allowing this class to repad, as otherwise the base model would also have to return the batch size, indices, seqlens, etc. so that this class could repad.

Collaborator:

TBH it might make more sense! This way only the base model unpads, and the other models can freely do the unpadding. Kinda up to you 🤗

@NohTow (Contributor) commented on Dec 30, 2024

Hello,

Thanks for opening this PR, I really love your work on CDE and I look forward to seeing how it will perform with ModernBERT!
FYI, this feature (being able to pass inputs_embeds) has also been requested for other use cases here, so besides backward compatibility and CDE, it seems like this feature is used by the community and is thus a very cool addition.

I'll make a more thorough review when I come back from vacation this week, but I already checked how to implement it for the linked issue and it seems to be in line with the latest change from Tom.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@jxmorris12 (Contributor, Author)

Hello @jxmorris12, @ArthurZucker,

I pushed some changes into this PR to get it closer to completion. Let me know if you're not okay with this, and you can easily revert or delete the commits.
...

thanks Tom! Looks great.

@jxmorris12 (Contributor, Author)

Hi @NohTow @ArthurZucker just following up on this -- thanks :)

@znsoftm commented on Jan 7, 2025

When will it be merged into the main branch?

@ArthurZucker (Collaborator) left a comment

Thanks for updating, let's remove complexity and good to go!

Comment on lines +1042 to +1045
if inputs_embeds is not None:
    batch_size, seq_len = inputs_embeds.shape[:2]
else:
    batch_size, seq_len = input_ids.shape[:2]
Collaborator:

if you first embed input ids you don't need two branches!
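For illustration, a tiny standalone demo of the suggestion (toy tensors, not the real model): once the embedding layer has produced hidden_states from whichever input was given, a single shape read covers both cases.

    import torch

    embedding = torch.nn.Embedding(10, 4)
    input_ids = torch.randint(0, 10, (2, 5))
    inputs_embeds = None

    hidden_states = inputs_embeds if inputs_embeds is not None else embedding(input_ids)
    batch_size, seq_len = hidden_states.shape[:2]  # one code path for both cases
    print(batch_size, seq_len)  # 2 5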

Comment on lines +1056 to 1063
    with torch.no_grad():
        input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
            inputs=input_ids, attention_mask=attention_mask
        )
else:
    inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
        inputs=inputs_embeds, attention_mask=attention_mask
    )
Collaborator:

only the input is different here! if you first embed the input, then you don't need two branches!

Contributor (Author):

Well, the old code contains a torch.no_grad here; that's why there are two branches. Did you all mean to wrap the unpad function in a no_grad? If so, we need to keep it. We definitely can't unpad inputs_embeds in a no_grad block because it will break gradient flow.
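A quick standalone demo of why the no_grad branch cannot simply be reused for inputs_embeds (toy tensors, not the model code):

    import torch

    x = torch.randn(4, 8, requires_grad=True)  # stands in for inputs_embeds

    # Unpadding-style gather inside no_grad: the result is detached from the graph,
    # so nothing upstream of the unpad would ever receive gradients.
    with torch.no_grad():
        y = x[torch.tensor([0, 2])]
    print(y.requires_grad)  # False

    # The same gather outside no_grad keeps the autograd graph intact.
    z = x[torch.tensor([0, 2])]
    print(z.requires_grad)  # True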

Collaborator:

Ah, as I mentioned last time, I still don't understand why you would need gradient flow for padding / unpadding when it's weight agnostic?

Member:

I believe #35386 is a more detailed bug report on the gradient breaks

Collaborator:

My bad! Indeed for input embedding you need grads if you are training an encoder!


@NohTow (Contributor) commented on Jan 7, 2025

Hello,
Sorry for the delay.
As mentioned earlier, the code looks good to me.
I only have two small nitpicks. As mentioned by @ArthurZucker, first embedding the input_ids should be cleaner and less error-prone.
The second is that the XOR test is a bit less informative than checking that the two are not set together and then checking that at least one is defined. Maybe it's just me, but I find "You must specify exactly one of input_ids or inputs_embeds" a bit ambiguous, yet I don't have a better wording.

Besides those very small details, everything looks good!
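For reference, the two styles being compared, as standalone sketches (the sequential messages approximate the ones in modeling_bert; the XOR message is the one quoted above):

    def check_inputs_xor(input_ids=None, inputs_embeds=None):
        # XOR-style check, as in Llama and this PR: one combined message.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    def check_inputs_sequential(input_ids=None, inputs_embeds=None):
        # Sequential checks, as in BERT: two more specific messages.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")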

@ArthurZucker (Collaborator)

The XOR test is used in a few models like Llama -> it's standard for us! 🤗

@jxmorris12 (Contributor, Author)

Sorry @NohTow - can you make a concrete suggestion on how to fix the code? It's not clear how the code can be much cleaner since all the conditionals seem necessary to me.

@NohTow (Contributor) commented on Jan 8, 2025

@jxmorris12 I am very sorry, my message was unclear. I did not mean that the code was incorrect, nor that testing the two conditions sequentially was cleaner; I just meant that from my reading, the error message could be ambiguous (in the case where the user does not feed anything, the "exactly" might not be explicit enough).
But again, that is really a nitpick and probably comes from me not properly reading the message, let's ignore that!

@jxmorris12 (Contributor, Author)

@jon-tow Oh okay, that makes sense to me! I agree it could be a really confusing error for the user. But maybe since it's in a lot of files (such as LLAMA) we could open a separate issue to improve that error message everywhere?

@NohTow (Contributor) commented on Jan 8, 2025

@jxmorris12 Yeah totally. I raised that because I compared the BERT implementation with this one, and the BERT one does the checks sequentially and thus has more informative messages. But since Arthur raised that it is already done like that in other models and is standard now, let's just follow the standard!

@ArthurZucker (Collaborator) left a comment

My comments are not addressed 😓

  • let's embed inputs, then compute shapes
  • let's never require grad on unpadding / padding as it's weight agnostic

@NohTow (Contributor) commented on Jan 9, 2025

For the latter, I think it is related to this.
It was discussed during the original PR, but it breaks the gradients.
I'll let @warner-benjamin give more information.

@tomaarsen (Member)

My comments are not addressed 😓

* let's embed inputs, then compute shapes

This is a very logical approach, so that we don't need

            if inputs_embeds is not None:
                batch_size, seq_len = inputs_embeds.shape[:2]
            else:
                batch_size, seq_len = input_ids.shape[:2]

But instead can just use

batch_size, seq_len = hidden_states.shape[:2]

However, in ModernBERT, before we can compute the embedded inputs, we need to potentially unpad. To unpad, we need to have an attention_mask, which needs the shapes. In short: we need the shapes before the input embeddings.

One alternative is to use attention_mask = torch.ones((input_ids or input_embeds).shape[:2], ...), but this feels like the same as the shape computation but as a 1-liner.
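For illustration, the ordering constraint described above as a small sketch (the helper name and defaults are assumptions, not the actual code):

    import torch

    def shapes_and_default_mask(input_ids=None, inputs_embeds=None, attention_mask=None, device=None):
        # The flash-attention path needs the shapes (to build a default attention_mask)
        # and then unpads before the embedding layer runs, so the shapes cannot be
        # read off hidden_states after embedding.
        batch_size, seq_len = (input_ids if inputs_embeds is None else inputs_embeds).shape[:2]
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
        return batch_size, seq_len, attention_mask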

* let's never require grad on unpadding / padding as it's weight agnostic

Please have a look at #35386, it looks like the torch.no_grad is killing the gradients that were there before re-padding. Running the same script from that PR on BERT does give gradients on the logits.

@warner-benjamin made a PR to patch it here: #35404, which makes the gradient for repadding optional.

  • Tom Aarsen

@ArthurZucker (Collaborator)

Got it! Okay, for output, we do need the gradients, my bad!
TLDR:

  • when training with input_ids: torch.no_grad for the input unpadding is fine, but you need gradients on the outputs
  • when training with input embeddings: gradients are needed, because you need backpropagation through them
  • at inference: no grad at all, but usually people handle this outside the modeling code.

So the only "gain" is using no_grad on input_ids all the time.

@ArthurZucker (Collaborator) left a comment

Thanks for answering my questions! Let's go 🤗


@tomaarsen tomaarsen merged commit 832c619 into huggingface:main Jan 9, 2025
16 checks passed