Merging DiVA to Levanter Main #779
base: main
Conversation
I am currently annoyed by how we initialize models, and this seems fine enough (cf. #780), so I don't have a super strong feeling on it right now. You could look at how we do LoRA if you want, but that's a bit of a different case.
Nice! A few minor comments. I don't understand everything, but overall it seems good to me!
lev_model: DivaModel = super().load_pretrained(
    DivaModel, ref, config, axis_mapping, resize_vocab_to_match_tokenizer, dtype
)  # type: ignore[assignment]
llm: Union[LlamaLMHeadModel | MistralLMHeadModel | GemmaLMHeadModel] = HFCheckpointConverter(
Do you need this Union around your 3.10 union?
elif "gemma" in model_id:
    config = GemmaConfig.from_hf_config(hf_config)
elif "mistral" in model_id:
    config = MistralConfig.from_hf_config(hf_config)
should you raise a better error if it's none of these?
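One way to do that (a sketch; the stub config classes stand in for Levanter's real ones, and the dispatch mirrors the hunk above):

```python
# Stub config classes standing in for Levanter's real ones (illustration only)
class LlamaConfig:
    @staticmethod
    def from_hf_config(hf_config):
        return ("llama", hf_config)

class GemmaConfig:
    @staticmethod
    def from_hf_config(hf_config):
        return ("gemma", hf_config)

class MistralConfig:
    @staticmethod
    def from_hf_config(hf_config):
        return ("mistral", hf_config)

def config_from_model_id(model_id, hf_config):
    # Dispatch on the model family in the checkpoint id; fail loudly
    # instead of falling through with an unbound `config`.
    if "llama" in model_id:
        return LlamaConfig.from_hf_config(hf_config)
    elif "gemma" in model_id:
        return GemmaConfig.from_hf_config(hf_config)
    elif "mistral" in model_id:
        return MistralConfig.from_hf_config(hf_config)
    raise ValueError(
        f"Unsupported reference decoder {model_id!r}; expected a llama, gemma, or mistral checkpoint"
    )
```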
return config


def get_prefix(tokenizer_ref):
doccomment maybe?
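Something like this, perhaps (the behavior described is inferred from `return prefix_tok, suffix_tok` and the surrounding config, so treat it as a guess at the intent rather than the helper's actual contract):

```python
def get_prefix(tokenizer_ref):
    """Return ``(prefix_tok, suffix_tok)`` for ``tokenizer_ref``.

    ``prefix_tok`` is the chat-template token ids that come before the
    user turn and ``suffix_tok`` the ids that come after it.

    Body elided; this sketch only illustrates the doccomment.
    """
    raise NotImplementedError
```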
return prefix_tok, suffix_tok


@LmConfig.register_subclass("diva")
AsrConfig?
init_from_submodel: bool = True

# Connector Config
pre_audio_prompt = property(lambda self: get_prefix(self.reference_decoder)[0])
cache_property or who cares?
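Presumably `functools.cached_property`; since `get_prefix` loads a tokenizer, caching on first access seems worth it. A sketch with a counter standing in for the tokenizer load (assumes the config class has a normal `__dict__`, which `cached_property` needs):

```python
from functools import cached_property

calls = {"n": 0}

def fake_get_prefix(ref):
    # Stand-in for get_prefix(), which loads a tokenizer on every call
    calls["n"] += 1
    return [101, 102], [103]

class DivaConfigSketch:
    reference_decoder = "some/tokenizer"  # hypothetical value

    @cached_property
    def pre_audio_prompt(self):
        # Computed on first access, then cached on the instance
        return fake_get_prefix(self.reference_decoder)[0]

cfg = DivaConfigSketch()
first = cfg.pre_audio_prompt
second = cfg.pre_audio_prompt
# calls["n"] is 1: the lookup ran only once
```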
)


class DivaModel(eqx.Module, ModelWithHfSerializationMixin[DivaConfig]):
maybe link to paper or something?
# Convert to Virtual LLM Tokens
virt_whisper_tokens = self.connector.transformer(
    (self.query_tokens + self.query_position_embeds).broadcast_axis(OtherAxes),
hrm i wouldn't think this should be necessary
text[
    {
        "batch": hax.arange(Batch),
        "position": (hax.sum(text_tokens == pad_token_id, "position") * -1) - 1,
you just broke my brain
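For anyone else parsing that index: with right-padded batches, counting the pad tokens in each row and negating gives a negative index that lands on the last real token. The same trick in plain numpy (pad id and token values are made up for illustration):

```python
import numpy as np

pad_token_id = 0
# Two right-padded sequences; the last real tokens are 7 and 9
text_tokens = np.array([
    [5, 6, 7, 0, 0],
    [8, 9, 0, 0, 0],
])

n_pad = (text_tokens == pad_token_id).sum(axis=1)  # pads per row: [2, 3]
last_pos = -n_pad - 1                              # negative index of last real token: [-3, -4]

last_real = text_tokens[np.arange(len(text_tokens)), last_pos]
# last_real == [7, 9]
```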
kl_proxy_loss = hax.dot(diff_distill, diff_distill, axis="embed") ** 0.5

# Compute Contrastive Loss on Input
# Correct for Normal Autoregressive Loss Mask
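Side note on the `kl_proxy_loss` line: dotting the teacher-student difference with itself over the embed axis and taking the square root is just the per-position Euclidean distance. In plain numpy, with made-up embeddings:

```python
import numpy as np

teacher = np.array([1.0, 2.0, 2.0])  # hypothetical teacher embedding
student = np.array([0.0, 0.0, 0.0])  # hypothetical student embedding

diff_distill = teacher - student
# dot(d, d) ** 0.5 == ||d||_2, the L2 distance used as the distillation proxy
kl_proxy_loss = np.dot(diff_distill, diff_distill) ** 0.5
# kl_proxy_loss == 3.0  (sqrt(1 + 4 + 4))
```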
do you want to check that attn mask is causal or eh?
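If you did want a guard, a dense-mask version might look like this (a sketch in plain numpy; Levanter's `AttentionMask` is a structured object, so in practice you would inspect that rather than materialize a matrix):

```python
import numpy as np

def is_causal(mask: np.ndarray) -> bool:
    # A causal mask lets position i attend only to positions j <= i,
    # i.e. the dense 0/1 matrix is lower-triangular ones.
    return bool(np.all(mask == np.tril(np.ones_like(mask))))

causal = np.tril(np.ones((4, 4)))  # standard causal mask
full = np.ones((4, 4))             # bidirectional mask
# is_causal(causal) is True; is_causal(full) is False
```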
Cleaned-up version of my code for the Distilled Voice Assistant models that I trained using a fork of Levanter!
@dlwh The main thing I want to check in with you on here is what design pattern you think would make sense for initializing the model weights from multiple other pretrained models. What I've done here is much cleaner than what I did originally for the paper, but it still feels a bit messy.
Testing procedure for the correctness of this training code:
I trained a new DiVA model with this updated code and Llama 3.2 1B, using the config in diva_flash.yaml.
Training log is here: https://wandb.ai/i18nlp/levanter/runs/jnxp463y?nw=nwuserheld
Resulting model is on HF in PyTorch form here: https://huggingface.co/WillHeld/DiVA-llama-3.2-1b
Demo confirming the result is ~reasonable is here for now: https://b3f161194b514a990f.gradio.live/