diff --git a/model2vec/distill/tokenizer.py b/model2vec/distill/tokenizer.py
index 173a2f2..e6c3087 100644
--- a/model2vec/distill/tokenizer.py
+++ b/model2vec/distill/tokenizer.py
@@ -2,6 +2,7 @@
 
 import json
 import logging
+from typing import Any
 
 from tokenizers import Tokenizer
 
@@ -36,11 +37,11 @@ def remove_tokens(tokenizer: Tokenizer, tokens_to_remove: list[str]) -> Tokenize
         logger.info("No tokens to remove.")
         return Tokenizer.from_str(tokenizer.to_str())
 
-    tokenizer_data = json.loads(tokenizer.to_str())
+    tokenizer_data: dict[str, Any] = json.loads(tokenizer.to_str())
 
     # Find all added tokens
-    added_tokens = tokenizer_data["added_tokens"]
-    added_tokens_str = {token["content"] for token in added_tokens}
+    added_tokens: list[dict[str, Any]] = tokenizer_data.get("added_tokens", [])
+    added_tokens_str: set[str] = {token["content"] for token in added_tokens}
 
     # Remove all added tokens from the list of tokens to remove.
     # Things will go bad if we keep them.
@@ -76,9 +77,9 @@ def remove_tokens(tokenizer: Tokenizer, tokens_to_remove: list[str]) -> Tokenize
         raise ValueError(f"Unknown model type {model_type}")
 
     # Reindex the special tokens (i.e., CLS and SEP for BertTokenizers.)
-    special_tokens_post_processor: dict[str, dict] = tokenizer_data["post_processor"]["special_tokens"]
-    for token, token_data in special_tokens_post_processor.items():
-        token_data["ids"] = [reindexed[token] for token in token_data["tokens"]]
+    added_tokens = tokenizer_data.get("added_tokens", [])
+    for token_data in added_tokens:
+        token_data["id"] = reindexed[token_data["content"]]
 
     # Reinitialize the tokenizer from the json.
     tokenizer = Tokenizer.from_str(json.dumps(tokenizer_data))
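
The last hunk replaces the post-processor reindexing with a loop that rewrites each added token's `id` from the `reindexed` mapping built earlier in `remove_tokens`. A minimal, self-contained sketch of what that loop does, using a toy `tokenizer_data` dict and a hypothetical `reindexed` mapping (token string → new id), assuming every added token survives the removal:

```python
from typing import Any

# Toy stand-in for the JSON produced by tokenizer.to_str(); the real structure
# comes from the `tokenizers` library. `reindexed` is a hypothetical mapping
# from token string to its new id after the vocabulary has been pruned.
tokenizer_data: dict[str, Any] = {
    "added_tokens": [
        {"id": 101, "content": "[CLS]", "special": True},
        {"id": 102, "content": "[SEP]", "special": True},
    ]
}
reindexed: dict[str, int] = {"[CLS]": 1, "[SEP]": 2}

# Mirrors the patched loop: update each added token's id in place so it
# points at the token's position in the reindexed vocabulary.
for token_data in tokenizer_data.get("added_tokens", []):
    token_data["id"] = reindexed[token_data["content"]]

assert [t["id"] for t in tokenizer_data["added_tokens"]] == [1, 2]
```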