Skip to content

Commit

Permalink
fix: don't rely on reported vocab size, log warning if inconsistent (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
stephantul authored Oct 23, 2024
1 parent 2f57b9c commit 5ed70a6
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion model2vec/distill/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,15 @@ def create_output_embeddings_from_model_name(
:return: The tokens and output embeddings.
"""
model = model.to(device)
ids = torch.arange(tokenizer.vocab_size)

# Quick check to see if the tokenizer is consistent.
vocab_length = len(tokenizer.get_vocab())
if vocab_length != tokenizer.vocab_size:
logger.warning(
f"Reported vocab size {tokenizer.vocab_size} is inconsistent with the vocab size {vocab_length}."
)

ids = torch.arange(vocab_length)

# Work-around to get the eos and bos token ids without having to go into tokenizer internals.
dummy_encoding = tokenizer.encode("A")
Expand Down

0 comments on commit 5ed70a6

Please sign in to comment.