fix: update added tokens to be more agnostic (#107)
stephantul authored Oct 23, 2024
1 parent 9c73fdf commit 2f57b9c
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions model2vec/distill/tokenizer.py
@@ -2,6 +2,7 @@

 import json
 import logging
+from typing import Any

 from tokenizers import Tokenizer

@@ -36,11 +37,11 @@ def remove_tokens(tokenizer: Tokenizer, tokens_to_remove: list[str]) -> Tokenizer:
         logger.info("No tokens to remove.")
         return Tokenizer.from_str(tokenizer.to_str())

-    tokenizer_data = json.loads(tokenizer.to_str())
+    tokenizer_data: dict[str, Any] = json.loads(tokenizer.to_str())

     # Find all added tokens
-    added_tokens = tokenizer_data["added_tokens"]
-    added_tokens_str = {token["content"] for token in added_tokens}
+    added_tokens: list[dict[str, Any]] = tokenizer_data.get("added_tokens", [])
+    added_tokens_str: set[str] = {token["content"] for token in added_tokens}

     # Remove all added tokens from the list of tokens to remove.
     # Things will go bad if we keep them.
@@ -76,9 +77,9 @@ def remove_tokens(tokenizer: Tokenizer, tokens_to_remove: list[str]) -> Tokenizer:
         raise ValueError(f"Unknown model type {model_type}")

     # Reindex the special tokens (i.e., CLS and SEP for BertTokenizers.)
-    special_tokens_post_processor: dict[str, dict] = tokenizer_data["post_processor"]["special_tokens"]
-    for token, token_data in special_tokens_post_processor.items():
-        token_data["ids"] = [reindexed[token] for token in token_data["tokens"]]
+    added_tokens = tokenizer_data.get("added_tokens", [])
+    for token_data in added_tokens:
+        token_data["id"] = reindexed[token_data["content"]]

     # Reinitialize the tokenizer from the json.
     tokenizer = Tokenizer.from_str(json.dumps(tokenizer_data))
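For reference, a minimal usage sketch of the patched function. The import path follows from the file location in this repository; the pretrained tokenizer name and the token strings to remove are purely illustrative assumptions, not part of the commit.

from tokenizers import Tokenizer

from model2vec.distill.tokenizer import remove_tokens

# Load any pretrained tokenizer from the Hugging Face hub (illustrative choice).
tokenizer = Tokenizer.from_pretrained("bert-base-uncased")

# Remove a couple of ordinary vocabulary tokens (illustrative strings).
# Added tokens (e.g. [CLS], [SEP]) are filtered out of the removal list and,
# after this commit, reindexed via their "content" field in "added_tokens",
# so the call no longer depends on a BERT-style post_processor layout.
smaller_tokenizer = remove_tokens(tokenizer, tokens_to_remove=["hello", "world"])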
