fix: Add explicit errors for BPE and unigram, return tokenizer without changes if tokens_to_remove is empty (#54)

* Add explicit errors for BPE and unigram, return tokenizer without changes if tokens_to_remove is empty

* Add helpful message

* Clearer message

* Add nice logging

* Remove clunky sentence
stephantul authored Oct 3, 2024
1 parent c7bcae8 commit c34d862
Showing 2 changed files with 62 additions and 18 deletions.
11 changes: 10 additions & 1 deletion model2vec/distill/distillation.py
@@ -3,6 +3,7 @@
import numpy as np
from huggingface_hub import model_info
from sklearn.decomposition import PCA
from tokenizers.models import BPE, Unigram, WordPiece
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerFast

from model2vec.distill.inference import (
@@ -50,6 +51,14 @@ def distill(

    # Load original tokenizer. We need to keep this to tokenize any tokens in the vocabulary.
    original_tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(model_name)

    if vocabulary and isinstance(original_tokenizer.backend_tokenizer.model, (BPE, Unigram)):
        raise ValueError(
            "You passed a vocabulary, but the model you are using does not use a WordPiece tokenizer. "
            "This is not supported yet. "
            "Feel free to open an issue if this is a blocker: https://github.com/MinishLab/model2vec/issues"
        )

    original_model: PreTrainedModel = AutoModel.from_pretrained(model_name)
    # Make a base list of tokens.
    tokens: list[str] = []
@@ -79,7 +88,7 @@ def distill(
    # We need to set embeddings to None because we don't know the dimensions of the embeddings yet.
    embeddings = None

    if vocabulary is not None:
    if vocabulary:
        # Preprocess the vocabulary with the original tokenizer.
        preprocessed_vocabulary = preprocess_vocabulary(original_tokenizer.backend_tokenizer, vocabulary)
        n_tokens_before = len(preprocessed_vocabulary)
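In practice, the new guard means that passing a custom vocabulary for a BPE- or Unigram-backed model fails fast, before the model weights are even loaded. A hedged sketch of what that looks like from the caller's side; the model name "roberta-base" and the exact distill(model_name=..., vocabulary=...) call shape are illustrative assumptions, not taken from this diff:

from tokenizers.models import BPE, Unigram
from transformers import AutoTokenizer

from model2vec.distill.distillation import distill

# RoBERTa ships a byte-level BPE backend, so the new isinstance check matches.
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
print(isinstance(tokenizer.backend_tokenizer.model, (BPE, Unigram)))  # True

try:
    # With this commit, a custom vocabulary on a BPE/Unigram model raises immediately.
    distill(model_name="roberta-base", vocabulary=["some", "extra", "tokens"])  # assumed call shape
except ValueError as err:
    print(err)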
69 changes: 52 additions & 17 deletions model2vec/distill/tokenizer.py
@@ -23,7 +23,19 @@ def remove_tokens(tokenizer: Tokenizer, tokens_to_remove: list[str]) -> Tokenizer:
    :param tokenizer: The tokenizer to remove tokens from.
    :param tokens_to_remove: The tokens to remove.
    :return: The modified tokenizer.
    :raises ValueError: If the tokenizer model type is not supported.
    """
    model_vocab = set(tokenizer.get_vocab())
    # This triggers when tokens_to_remove is empty or when there is no overlap
    # between the tokens to remove and the model vocabulary.
    if not set(tokens_to_remove).intersection(model_vocab):
        # NOTE: return a copy.
        if tokens_to_remove:
            logger.info("No tokens to remove, none of the tokens were in the vocabulary.")
        else:
            logger.info("No tokens to remove.")
        return Tokenizer.from_str(tokenizer.to_str())

    tokenizer_data = json.loads(tokenizer.to_str())

    # Find all added tokens
@@ -35,20 +47,31 @@ def remove_tokens(tokenizer: Tokenizer, tokens_to_remove: list[str]) -> Tokenizer:
    tokens_to_remove = [token for token in tokens_to_remove if token not in added_tokens_str]

    # Load the vocabulary.
    vocab: dict[str, int] = tokenizer_data["model"]["vocab"]
    n_tokens = len(vocab)

    # Remove the tokens.
    for token in tokens_to_remove:
        if vocab.pop(token, None) is None:
            logger.warning(f"Token {token} was not in the vocabulary.")

    n_removed = n_tokens - len(vocab)
    logger.info(f"Removed {n_removed} tokens from the vocabulary.")

    # Reindex the vocabulary so that it is contiguous.
    reindexed = {token: idx for idx, (token, _) in enumerate(sorted(vocab.items(), key=lambda x: x[1]))}
    tokenizer_data["model"]["vocab"] = reindexed
    model_type = tokenizer_data["model"]["type"]

    match model_type:
        case "WordPiece":
            # Vocab is a dictionary.
            vocab: dict[str, int] = tokenizer_data["model"]["vocab"]
            n_tokens = len(vocab)

            # Remove the tokens.
            for token in tokens_to_remove:
                if vocab.pop(token, None) is None:
                    logger.warning(f"Token {token} was not in the vocabulary.")

            n_removed = n_tokens - len(vocab)
            logger.info(f"Removed {n_removed} tokens from the vocabulary.")

            # Reindex the vocabulary so that it is contiguous.
            reindexed = {token: idx for idx, (token, _) in enumerate(sorted(vocab.items(), key=lambda x: x[1]))}
            tokenizer_data["model"]["vocab"] = reindexed
        case "Unigram":
            raise ValueError("Removing tokens from a unigram tokenizer is not supported.")
        case "BPE":
            raise ValueError("Removing tokens from a bpe tokenizer is not supported.")
        case _:
            raise ValueError(f"Unknown model type {model_type}")

    # Reindex the special tokens (i.e., CLS and SEP for BertTokenizers.)
    special_tokens_post_processor: dict[str, dict] = tokenizer_data["post_processor"]["special_tokens"]
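The WordPiece branch above boils down to two steps: pop the requested tokens from the vocab mapping, then reassign ids so they stay contiguous in old-id order. A minimal, self-contained sketch of that step with a made-up toy vocabulary:

# Toy illustration of the remove-and-reindex logic in the WordPiece case.
vocab = {"[PAD]": 0, "[UNK]": 1, "hello": 2, "world": 3, "foo": 4}
tokens_to_remove = ["world"]

for token in tokens_to_remove:
    vocab.pop(token, None)  # missing tokens are skipped (the real code logs a warning)

# Reassign contiguous ids, preserving the order of the old ids.
reindexed = {token: idx for idx, (token, _) in enumerate(sorted(vocab.items(), key=lambda x: x[1]))}
print(reindexed)  # {'[PAD]': 0, '[UNK]': 1, 'hello': 2, 'foo': 3}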
@@ -68,12 +91,24 @@ def add_tokens(tokenizer: Tokenizer, tokens_to_add: list[str]) -> Tokenizer:
    :param tokenizer: The tokenizer to add tokens to.
    :param tokens_to_add: The tokens to add.
    :return: The modified tokenizer.
    :raises ValueError: If the tokenizer model type is not supported.
    """
    data = json.loads(tokenizer.to_str())

    vocab: dict[str, int] = data["model"]["vocab"]
    for token in tokens_to_add:
        vocab[token] = len(vocab)
    model = data["model"]["type"]

    match model:
        case "WordPiece":
            wordpiece_vocab: dict[str, int] = data["model"]["vocab"]
            for token in tokens_to_add:
                if token not in wordpiece_vocab:
                    wordpiece_vocab[token] = len(wordpiece_vocab)
        case "Unigram":
            raise ValueError("Adding tokens to a unigram tokenizer is not supported.")
        case "BPE":
            raise ValueError("Adding tokens to a bpe tokenizer is not supported.")
        case _:
            raise ValueError(f"Unknown model type {model}")

    tokenizer = Tokenizer.from_str(json.dumps(data))
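Putting the two helpers together, a hedged end-to-end sketch. The model name "bert-base-uncased" and the probe tokens are illustrative assumptions; any WordPiece-backed tokenizer with a standard post-processor should behave the same way:

# Hedged usage sketch of remove_tokens/add_tokens after this commit.
from transformers import AutoTokenizer

from model2vec.distill.tokenizer import add_tokens, remove_tokens

backend = AutoTokenizer.from_pretrained("bert-base-uncased").backend_tokenizer

# Empty list (or no overlap with the vocabulary): an unchanged copy comes back
# and an info message explains why, instead of reindexing needlessly.
unchanged = remove_tokens(backend, tokens_to_remove=[])
assert unchanged.get_vocab() == backend.get_vocab()

# WordPiece vocabularies can still be edited; BPE and Unigram ones now raise a ValueError.
smaller = remove_tokens(backend, tokens_to_remove=["unwanted"])
larger = add_tokens(smaller, tokens_to_add=["brandnewtoken"])
print(len(backend.get_vocab()), len(smaller.get_vocab()), len(larger.get_vocab()))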
