
Support modernBERT for encoder-decoder models #35385

Open
Bachstelze opened this issue Dec 21, 2024 · 2 comments
Labels
Feature request

Comments

@Bachstelze

Feature request

The docs state that the EncoderDecoderModel can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the encoder. However, ModernBERT isn't supported as the decoder:

File "/content/syntax_transformer/data/../models/encoderDecoder.py", line 40, in __init__
    self.model = EncoderDecoderModel.from_encoder_decoder_pretrained(
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/encoder_decoder/modeling_encoder_decoder.py", line 538, in from_encoder_decoder_pretrained
    decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/auto/auto_factory.py", line 567, in from_pretrained
    raise ValueError(
ValueError: Unrecognized configuration class <class 'transformers.models.modernbert.configuration_modernbert.ModernBertConfig'> for this kind of AutoModel: AutoModelForCausalLM.
Model type should be one of AriaTextConfig, BambaConfig, BartConfig, BertConfig, BertGenerationConfig, BigBirdConfig, BigBirdPegasusConfig, BioGptConfig, BlenderbotConfig, BlenderbotSmallConfig, BloomConfig, CamembertConfig, LlamaConfig, CodeGenConfig, CohereConfig, Cohere2Config, CpmAntConfig, CTRLConfig, Data2VecTextConfig, DbrxConfig, ElectraConfig, ErnieConfig, FalconConfig, FalconMambaConfig, FuyuConfig, GemmaConfig, Gemma2Config, GitConfig, GlmConfig, GPT2Config, GPT2Config, GPTBigCodeConfig, GPTNeoConfig, GPTNeoXConfig, GPTNeoXJapaneseConfig, GPTJConfig, GraniteConfig, GraniteMoeConfig, JambaConfig, JetMoeConfig, LlamaConfig, MambaConfig, Mamba2Config, MarianConfig, MBartConfig, MegaConfig, MegatronBertConfig, MistralConfig, MixtralConfig, MllamaConfig, MoshiConfig, MptConfig, MusicgenConfig, MusicgenMelodyConfig, MvpConfig, NemotronConfig, OlmoConfig, Olmo2Config, OlmoeConfig, OpenLlamaConfig, OpenAIGPTConfig, OPTConfig, PegasusConfig, PersimmonConfig, PhiConfig, Phi3Config, PhimoeConfig, PLBartConfig, ProphetNetConfig, QDQBertConfig, Qwen2Config, Qwen2MoeConfig, RecurrentGemmaConfig, ReformerConfig, RemBertConfig, RobertaConfig, RobertaPreLayerNormConfig, RoCBertConfig, RoFormerConfig, RwkvConfig, Speech2Text2Config, StableLmConfig, Starcoder2Config, TransfoXLConfig, TrOCRConfig, WhisperConfig, XGLMConfig, XLMConfig, XLMProphetNetConfig, XLMRobertaConfig, XLMRobertaXLConfig, XLNetConfig, XmodConfig, ZambaConfig.
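
For reference, a minimal sketch of the call that reproduces this, assuming the public answerdotai/ModernBERT-base checkpoint (any ModernBERT checkpoint used as the decoder hits the same error):

from transformers import EncoderDecoderModel

# ModernBERT loads fine as the encoder (via AutoModel), but there is no
# AutoModelForCausalLM mapping for ModernBertConfig, so loading it as the
# decoder raises the ValueError shown above.
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "answerdotai/ModernBERT-base",  # encoder
    "answerdotai/ModernBERT-base",  # decoder -> ValueError
)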

Motivation

ModernBERT has better performance and a longer context length.

Your contribution

How is it possible to support ModernBERT? It isn't that different from other BERT models.

Bachstelze added the Feature request label on Dec 21, 2024
@NielsRogge
Contributor

The reason ModernBERT can't yet be used as a decoder is that it does not include a cross-attention module.

When you use the EncoderDecoderModel class and want to initialize the decoder's weights from a pre-trained encoder-only model (like ModernBERT), the modeling_xxx.py file needs to support cross-attention (and a causal attention mask). This is supported in modeling_bert.py, as can be seen here. For ModernBERT, explicit support for the config.is_decoder attribute (and the corresponding implementation) would need to be added.
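
For contrast, a minimal sketch of the path that already works with BERT, whose modeling code implements config.is_decoder and cross-attention (using the standard google-bert/bert-base-uncased checkpoint):

from transformers import EncoderDecoderModel

# from_encoder_decoder_pretrained sets is_decoder=True and add_cross_attention=True
# on the decoder config; modeling_bert.py honours those flags and adds a
# cross-attention block whose weights are newly initialized, so the resulting
# model needs fine-tuning before use.
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "google-bert/bert-base-uncased",  # encoder
    "google-bert/bert-base-uncased",  # decoder
)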

@Bachstelze
Author

To include a cross-attention module, the affected modules would need changes along these lines:

if self.add_cross_attention:
    if not self.is_decoder:
        raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
    self.crossattention = ModernBertAttention(config)

added in ModernBertEncoderLayer.__init__() (which would then be better renamed to ModernBertLayer).

If ModernBertAttention is instantiated as a cross-attention module, the keys and values come from an encoder; the attention mask needs to be such that the encoder's padding tokens are not attended to.
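
As a rough sketch, the layer's forward pass could route the encoder states into that module like below; the argument names are borrowed from modeling_bert.py and are not part of the existing ModernBertAttention API:

# Illustrative only: ModernBertAttention would have to be extended to accept
# encoder key/value states and a separate encoder attention mask.
if self.is_decoder and encoder_hidden_states is not None:
    cross_attention_outputs = self.crossattention(
        hidden_states,                                # queries from the decoder
        encoder_hidden_states=encoder_hidden_states,  # keys/values from the encoder
        attention_mask=encoder_attention_mask,        # masks the encoder's padding tokens
    )
    hidden_states = cross_attention_outputs[0]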

modeling_modernbert.py is generated from modular_modernbert.py.
How can this generation be triggered?
