fix: don't rely on reported vocab size, log warning if inconsistent
stephantul committed Oct 23, 2024
1 parent 9c73fdf commit 30933a5
Showing 1 changed file with 9 additions and 1 deletion.
model2vec/distill/inference.py (10 changes: 9 additions & 1 deletion)
@@ -115,7 +115,15 @@ def create_output_embeddings_from_model_name(
     :return: The tokens and output embeddings.
     """
     model = model.to(device)
-    ids = torch.arange(tokenizer.vocab_size)
+
+    # Quick check to see if the tokenizer is consistent.
+    vocab_length = len(tokenizer.get_vocab())
+    if vocab_length != tokenizer.vocab_size:
+        logger.warning(
+            f"Reported vocab size {tokenizer.vocab_size} is inconsistent with the vocab size {vocab_length}."
+        )
+
+    ids = torch.arange(vocab_length)
 
     # Work-around to get the eos and bos token ids without having to go into tokenizer internals.
     dummy_encoding = tokenizer.encode("A")

Codecov / codecov/patch annotation: the added line model2vec/distill/inference.py#L122 was not covered by tests.
