Merging DiVA to Levanter Main #779

Status: Open
Helw150 wants to merge 5 commits into main
Conversation

@Helw150 Helw150 commented Oct 30, 2024

Cleaned up version of my code for the Distilled Voice Assistant models that I trained using a fork of Levanter!

@dlwh The main thing I want to check in with you here is what design pattern you think would make sense for initializing the model weights from multiple other pretrained models. What I've done here is much cleaner than what I did originally for the paper, but it still feels a bit messy.

Testing Procedure for the correctness of this training code:
I trained a new DiVA model with this updated code and Llama 3.2 1B using the config in diva_flash.yaml.

Training Log is here: https://wandb.ai/i18nlp/levanter/runs/jnxp463y?nw=nwuserheld
Resulting model is on HF in PyTorch form here: https://huggingface.co/WillHeld/DiVA-llama-3.2-1b
Demo which confirmed the result is ~reasonable is here for now: https://b3f161194b514a990f.gradio.live/
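
For illustration only, here is a rough sketch (not the PR's actual code) of the "init from multiple pretrained submodels" pattern being discussed: build the composite model with fresh weights, load each pretrained submodel separately, then graft the pretrained parameters in. The helpers `load_pretrained_encoder` / `load_pretrained_decoder` and the `DivaModel.init` signature are assumptions for the sketch, not Levanter's API.

```python
import dataclasses


def init_from_submodels(config, *, key):
    # 1. Build the composite model with freshly initialized (random) weights.
    model = DivaModel.init(config.Vocab, config, key=key)  # assumed constructor

    # 2. Load each pretrained submodel on its own (e.g. via an HF checkpoint converter).
    encoder = load_pretrained_encoder(config.reference_encoder)  # hypothetical helper
    decoder = load_pretrained_decoder(config.reference_decoder)  # hypothetical helper

    # 3. Graft the pretrained weights into the composite model, leaving only the
    #    connector / query tokens randomly initialized. eqx.tree_at would also work here.
    return dataclasses.replace(model, encoder=encoder, decoder=decoder)
```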

@Helw150 Helw150 requested a review from dlwh October 30, 2024 04:32
dlwh commented Oct 30, 2024

I am currently annoyed by how we initialize models, and this seems fine enough (cf. #780), so I don't have a super strong feeling on it right now. You could look at how we do LoRA if you want, but that's a bit of a different case.

@dlwh dlwh left a comment


Nice! A few minor comments. I don't understand everything, but overall it seems good to me!

lev_model: DivaModel = super().load_pretrained(
    DivaModel, ref, config, axis_mapping, resize_vocab_to_match_tokenizer, dtype
)  # type: ignore[assignment]
llm: Union[LlamaLMHeadModel | MistralLMHeadModel | GemmaLMHeadModel] = HFCheckpointConverter(

Do you need this Union around your 3.10 union?
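
For reference, a minimal sketch of what the review is pointing at: with PEP 604 syntax the `typing.Union` wrapper is redundant, so the annotation can just be the bare union (or, equivalently, `Union[LlamaLMHeadModel, MistralLMHeadModel, GemmaLMHeadModel]` if older tooling matters).

```python
# typing.Union and the "|" operator do the same thing; one of them is enough:
llm: LlamaLMHeadModel | MistralLMHeadModel | GemmaLMHeadModel
```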

elif "gemma" in model_id:
config = GemmaConfig.from_hf_config(hf_config)
elif "mistral" in model_id:
config = MistralConfig.from_hf_config(hf_config)

should you raise a better error if it's none of these?
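
A minimal sketch of what that could look like; the error type and message are just a suggestion:

```python
elif "mistral" in model_id:
    config = MistralConfig.from_hf_config(hf_config)
else:
    raise ValueError(
        f"Unsupported reference model {model_id!r}; expected a Llama, Gemma, or Mistral checkpoint."
    )
```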

return config


def get_prefix(tokenizer_ref):

doccomment maybe?
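
A possible doccomment; the wording is a guess at the intent (pulling the chat-template token ids that surround the user turn), not taken from the PR:

```python
def get_prefix(tokenizer_ref):
    """Return (prefix_tok, suffix_tok): the token id sequences the chat template of
    `tokenizer_ref` places before and after the user turn, so the audio embeddings
    can be spliced in between them."""
```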

return prefix_tok, suffix_tok


@LmConfig.register_subclass("diva")

AsrConfig?

init_from_submodel: bool = True

# Connector Config
pre_audio_prompt = property(lambda self: get_prefix(self.reference_decoder)[0])

cache_property or who cares?
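
If caching is wanted: `functools.cached_property` needs a real method rather than `property(lambda ...)` and an instance that allows attribute assignment, which a frozen config dataclass may not. A hedged alternative is to memoize `get_prefix` itself, assuming `tokenizer_ref` is a hashable repo id string:

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def get_prefix(tokenizer_ref):
    ...


# The properties can then stay exactly as written; repeated accesses hit the cache:
pre_audio_prompt = property(lambda self: get_prefix(self.reference_decoder)[0])
```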

)


class DivaModel(eqx.Module, ModelWithHfSerializationMixin[DivaConfig]):

maybe link to paper or something?
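
One low-effort option is a class doccomment; a sketch with the link left as a placeholder rather than guessed:

```python
class DivaModel(eqx.Module, ModelWithHfSerializationMixin[DivaConfig]):
    """Distilled Voice Assistant (DiVA): a Whisper encoder bridged into a pretrained
    LLM through a learned connector and query tokens, trained by distilling from the
    text-only LLM. See <paper / model card link> for details."""
```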


# Convert to Virtual LLM Tokens
virt_whisper_tokens = self.connector.transformer(
    (self.query_tokens + self.query_position_embeds).broadcast_axis(OtherAxes),

hrm i wouldn't think this should be necessary

text[
    {
        "batch": hax.arange(Batch),
        "position": (hax.sum(text_tokens == pad_token_id, "position") * -1) - 1,

you just broke my brain
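
For anyone else it breaks: assuming right padding, `hax.sum(text_tokens == pad_token_id, "position")` counts pad tokens per row, so `-(num_pads) - 1` indexes from the end of the row and lands on the last non-pad token. A plain NumPy sketch of the same trick:

```python
import numpy as np

pad_token_id = 0
text_tokens = np.array([[5, 7, 9, pad_token_id, pad_token_id],
                        [3, 4, 6, 8, 2]])

# -(#pads) - 1 counts back from the end, skipping the trailing pads
position = -(text_tokens == pad_token_id).sum(axis=1) - 1      # [-3, -1]
last_real = text_tokens[np.arange(text_tokens.shape[0]), position]
print(last_real)                                               # [9 2]
```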

kl_proxy_loss = hax.dot(diff_distill, diff_distill, axis="embed") ** 0.5

# Compute Contrastive Loss on Input
# Correct for Normal Autoregressive Loss Mask

do you want to check that attn mask is causal or eh?
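
A hedged sketch of such a check; whether Levanter's attention mask object actually exposes an `is_causal` flag under that name is an assumption, so treat this as pseudocode:

```python
# Assumed attribute name; adjust to whatever the AttentionMask class really exposes.
if not getattr(attn_mask, "is_causal", False):
    raise ValueError("DiVA's distillation/contrastive loss assumes a causal attention mask")
```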
