update modular_modernbert -- add inputs_embeds param to ModernBertModel #35373
Conversation
Thanks! Yep, this would be nice to add. Let's make it a bit simpler (see how Llama has fewer if/elses) and add a small test! 🤗
First of all, inputs_embeds shouldn't fully replace `self.embeddings(input_ids)`, because that call also applies layer normalization and dropout. So now both input_ids and inputs_embeds are passed to ModernBertEmbeddings, much like how BertEmbeddings is implemented. I also added `inputs_embeds` to the docstring, propagated the changes to the other model classes, and introduced an error if input_ids and inputs_embeds are both provided or neither is. Lastly, I fixed an issue where the device was derived solely from input_ids when building the attention_mask.
Also reintroduce inputs_embeds test
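For context, a minimal sketch (a toy module, not the PR's actual code) of an embeddings layer that accepts either input, in the spirit of BertEmbeddings, so that norm and dropout run in both cases:

```python
import torch
import torch.nn as nn


class EmbeddingsSketch(nn.Module):
    """Illustrative only: norm + dropout must also run on inputs_embeds,
    not just on the output of the token embedding lookup."""

    def __init__(self, vocab_size: int, hidden_size: int, dropout: float = 0.1):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.norm = nn.LayerNorm(hidden_size)
        self.drop = nn.Dropout(dropout)

    def forward(self, input_ids=None, inputs_embeds=None):
        # The caller is expected to pass exactly one of the two inputs.
        if inputs_embeds is None:
            inputs_embeds = self.tok_embeddings(input_ids)
        return self.drop(self.norm(inputs_embeds))
```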
Hello @jxmorris12, @ArthurZucker, I pushed some changes into this PR to get it closer to completion. Let me know if you're not okay with this, and you can easily revert or delete the commits. The changes:
This is sadly not as simple as it seems, due to:

```python
if self.config._attn_implementation == "flash_attention_2":
    if indices is None and cu_seqlens is None and max_seqlen is None:
```
It's a bit frustrating that this entire tree is necessary, but there's no other convenient way to prevent the base model from repadding while still allowing this class to repad, as otherwise the base model would also have to return the batch size, indices, seqlens, etc. so that this class could repad.
TBH it might make more sense! This way only the base model unpads, and the other models can freely do the repadding. Kinda up to you 🤗
Hello, thanks for opening this PR! I really love your work on CDE and I look forward to seeing how it will perform with ModernBERT. I'll do a more thorough review when I come back from vacation this week, but I already checked how to implement it for the linked issue, and it seems to be in line with the latest change from Tom.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
thanks Tom! Looks great.
Hi @NohTow @ArthurZucker just following up on this -- thanks :)
When will it be merged into the main branch?
Thanks for updating, let's remove complexity and good to go!
```python
if inputs_embeds is not None:
    batch_size, seq_len = inputs_embeds.shape[:2]
else:
    batch_size, seq_len = input_ids.shape[:2]
```
if you first embed input ids you don't need two branches!
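For illustration, a minimal sketch of that suggestion (toy embedding layer, hypothetical variable names): embed whichever input was provided, then read the shapes from the resulting tensor.

```python
import torch
import torch.nn as nn

tok_embeddings = nn.Embedding(100, 16)   # stand-in for the model's embedding layer
input_ids = torch.randint(0, 100, (2, 8))
inputs_embeds = None                     # or a (batch, seq, hidden) tensor

# Embed first, then take the shapes from a single tensor,
# so no input_ids / inputs_embeds branch is needed afterwards.
hidden_states = tok_embeddings(input_ids) if inputs_embeds is None else inputs_embeds
batch_size, seq_len = hidden_states.shape[:2]
```

As discussed further down in the thread, ModernBERT's unpadding path needs the shapes before the embedding call, which is why this simplification is not entirely free here.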
```python
    with torch.no_grad():
        input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
            inputs=input_ids, attention_mask=attention_mask
        )
else:
    inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
        inputs=inputs_embeds, attention_mask=attention_mask
    )
```
only the input is different here! if you first embed the input, then you don't need two branches!
Well, the old code contains a torch.no_grad here; that's why there are two branches. Did you all mean to wrap the unpad function in a no_grad? If so, we need to keep it. We definitely can't unpad inputs_embeds in a no_grad block, because it would break gradient flow.
Ah, as I mentioned last time, I still don't understand why you would need gradient flow for padding / unpadding when it's weight-agnostic?
I believe #35386 is a more detailed bug report on the gradient breaks
My bad! Indeed for input embedding you need grads if you are training an encoder!
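To spell out why the two branches ended up staying, here is a hedged toy sketch (the `toy_unpad` helper below is a stand-in, not `_unpad_modernbert_input`, whose real return values also include cu_seqlens and max_seqlen): integer token ids never carry gradients, so their unpadding can run under `torch.no_grad()`, while unpadding `inputs_embeds` must stay differentiable.

```python
import torch

def toy_unpad(inputs, attention_mask):
    # Stand-in for the unpadding step: keep only non-padded positions.
    indices = attention_mask.flatten().nonzero(as_tuple=False).flatten()
    return inputs.view(-1, *inputs.shape[2:])[indices], indices

attention_mask = torch.tensor([[1, 1, 0], [1, 0, 0]])

# input_ids are integer indices: no gradient ever flows through them,
# so their unpadding may run under no_grad.
input_ids = torch.randint(0, 100, (2, 3))
with torch.no_grad():
    flat_ids, _ = toy_unpad(input_ids, attention_mask)

# inputs_embeds may require grad (e.g. when training an encoder on top of
# precomputed embeddings), so its unpadding must NOT be wrapped in no_grad.
inputs_embeds = torch.randn(2, 3, 4, requires_grad=True)
flat_embeds, _ = toy_unpad(inputs_embeds, attention_mask)
flat_embeds.sum().backward()
print(inputs_embeds.grad is not None)  # True: gradient flow is preserved
```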
Hello, besides those very small details, everything looks good!
The xor test is used on a few models like Llama -> it's standard for us! 🤗
Sorry @NohTow - can you make a concrete suggestion on how to fix the code? It's not clear to me how the code can be much cleaner, since all the conditionals seem necessary.
@jxmorris12 I am very sorry, my message was unclear. I did not mean that the code was incorrect, nor that testing the two conditions sequentially was cleaner; I just meant that, from my reading, the error message could be ambiguous (in the case where the user does not feed anything, the "exactly" might not be explicit enough).
@jon-tow Oh okay, that makes sense to me! I agree it could be a really confusing error for the user. But since it appears in a lot of files (such as Llama), maybe we could open a separate issue to improve that error message everywhere?
@jxmorris12 Yeah, totally. I raised that because I compared this implementation with the BERT one, and the BERT one does the check sequentially and thus has more informative messages, but since Arthur pointed out that it is already done like this in other models and is standard now, let's just follow the standard!
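For reference, the xor-style check being discussed looks roughly like this (paraphrased from the pattern used in models such as Llama; the exact wording in the PR may differ):

```python
def check_exactly_one(input_ids=None, inputs_embeds=None):
    # XOR-style check: exactly one of the two inputs must be provided.
    # Both "both provided" and "neither provided" hit the same message,
    # which is the ambiguity discussed above.
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
```

A sequential check, as in BertModel, would raise two distinct messages instead, at the cost of an extra branch.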
My comments are not addressed 😓
- let's embed the inputs, then compute the shapes
- let's never require grad on unpadding / padding, as it's weight-agnostic

For the latter, I think it is related to this.
This is a very logical approach, so that we don't need

```python
if inputs_embeds is not None:
    batch_size, seq_len = inputs_embeds.shape[:2]
else:
    batch_size, seq_len = input_ids.shape[:2]
```

but can instead just use

```python
batch_size, seq_len = hidden_states.shape[:2]
```

However, in ModernBERT, before we can compute the embedded inputs, we need to potentially unpad. To unpad, we need to have an attention_mask. One alternative is to use …
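To make that ordering constraint concrete, here is a rough, hypothetical sketch (the helper name and structure are made up for illustration) of why the shapes are needed before the embedding call on the unpadded path; it does not attempt to reconstruct the elided alternative:

```python
import torch

def shapes_before_unpadding(input_ids=None, inputs_embeds=None, attention_mask=None):
    # The shapes must be read from whichever input was provided, because a
    # default attention_mask built from them is needed by the unpadding step,
    # and unpadding runs before the embedding layer in the flash-attention path.
    batch_size, seq_len = (
        inputs_embeds.shape[:2] if inputs_embeds is not None else input_ids.shape[:2]
    )
    if attention_mask is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
    return batch_size, seq_len, attention_mask
```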
Please have a look at #35386; it looks like @warner-benjamin made a PR to patch it here: #35404, which makes the gradient for repadding optional.
Got it! Okay, for the output we do need the gradients, my bad!
So the only "gain" is using no_grad on input_ids all the time.
Thanks for answering my questions! Let's go 🤗
What does this PR do?
Hi! Congrats on the release of ModernBERT; it looks amazing. I'm interested in using ModernBERT eventually to train a new Contextual Document Embeddings model.
One desired feature is to pass the contextual and word embeddings together in the second stage, which requires setting the `inputs_embeds` kwarg so that we can pass hidden states directly. This is a feature of typical BERT and other transformer implementations but isn't yet allowed by ModernBERT, so I added it. It's only a few additional lines of code.

cc: @warner-benjamin @tomaarsen @orionw @staghado @bclavie @NohTow @ArthurZucker
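For anyone landing on this PR, a small usage sketch of the feature once available (it assumes a transformers version that includes this change and uses the public `answerdotai/ModernBERT-base` checkpoint; illustrative only, not the PR's test code):

```python
import torch
from transformers import AutoTokenizer, ModernBertModel

tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
model = ModernBertModel.from_pretrained("answerdotai/ModernBERT-base")

inputs = tokenizer("ModernBERT with inputs_embeds", return_tensors="pt")

# Build embeddings outside the model (e.g. to mix in contextual vectors),
# then feed them in place of input_ids.
inputs_embeds = model.get_input_embeddings()(inputs["input_ids"])
outputs = model(inputs_embeds=inputs_embeds, attention_mask=inputs["attention_mask"])
print(outputs.last_hidden_state.shape)
```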