Fix tokenizer
danielhanchen committed Jun 12, 2024
1 parent 8904605 commit 2a374c2
Showing 2 changed files with 159 additions and 33 deletions.
26 changes: 13 additions & 13 deletions unsloth/chat_templates.py
@@ -788,8 +788,6 @@ def _standardize_dataset(examples):
pass


-import re
-
def get_ollama_eos_tokens(tokenizer, extra_eos_tokens = []):
    added_tokens_decoder = tokenizer.added_tokens_decoder.values()
    added_tokens_decoder = [str(x) for x in added_tokens_decoder]
@@ -875,6 +873,15 @@ def construct_chat_template( \
        pass
    pass

+    error_msg = \
+        "Unsloth: Your prompt template must have 2 examples showing the user input {INPUT} "\
+        "and the assistant output {OUTPUT}\n\n"\
+        "For example what is not allowed is just:\n"\
+        "### Input:\\n{INPUT}\\n\\n### Response:\\n{OUTPUT}\\n\n\n"\
+        "What is required is 2x of this:\n"\
+        "### Input:\\n{INPUT}\\n\\n### Response:\\n{OUTPUT}\\n"\
+        "### Input:\\n{INPUT}\\n\\n### Response:\\n{OUTPUT}\\n"
+
    # O(N^2) search finding 2 repeated pieces of text
    j = len(template)-1
    at_least_one = False
@@ -885,22 +892,15 @@
        at_least_one = True
    pass
    if j > 0: j += 1
-    else: raise
+    else: raise RuntimeError(error_msg)

-    if not at_least_one: raise
+    if not at_least_one: raise RuntimeError(error_msg)

    # Repeated text
    instruction_response = template[j:]
    if instruction_response.count("{INPUT}") != 1 or instruction_response.count("{OUTPUT}") != 1:
-        raise RuntimeError(
-            "Unsloth: Your prompt template must have 2 examples showing the user input {INPUT} "\
-            "and the assistant output {OUTPUT}\n\n"\
-            "For example what is not allowed is just:\n"\
-            "### Input:\\n{INPUT}\\n\\n### Response:\\n{OUTPUT}\\n\n\n"\
-            "What is required is 2x of this:\n"\
-            "### Input:\\n{INPUT}\\n\\n### Response:\\n{OUTPUT}\\n"\
-            "### Input:\\n{INPUT}\\n\\n### Response:\\n{OUTPUT}\\n"
-        )
+        raise RuntimeError(error_msg)
    pass

    # 1st System, Instruction, Output pair
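As a quick illustration (not part of this commit), the O(N^2) scan above finds the longest trailing block of the template that repeats back-to-back, then demands exactly one {INPUT} and one {OUTPUT} inside it. A minimal self-contained sketch of that idea, with a helper name of our own choosing:

def find_repeated_block(template: str) -> str:
    """Return the longest trailing block that appears twice in a row."""
    n = len(template)
    best = None
    for length in range(1, n // 2 + 1):
        piece = template[n - length:]
        # Does this suffix also occupy the slot directly before itself?
        if template[n - 2 * length : n - length] == piece:
            best = piece
    if best is None:
        raise RuntimeError("Template must repeat its {INPUT}/{OUTPUT} turn twice")
    if best.count("{INPUT}") != 1 or best.count("{OUTPUT}") != 1:
        raise RuntimeError("Repeated turn must hold exactly one {INPUT} and one {OUTPUT}")
    return best

turn = "### Input:\n{INPUT}\n\n### Response:\n{OUTPUT}\n"
assert find_repeated_block(turn * 2) == turn

The commit's own version walks backwards with rfind instead, but the validation it enforces is the same: one full instruction/response turn, repeated twice.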
166 changes: 146 additions & 20 deletions unsloth/tokenizer_utils.py
@@ -185,6 +185,111 @@ def convert_to_fast_tokenizer(
pass


+# Check Mistral chat template without BOS / EOS
+mistral_template = \
+    "{% if messages[0]['role'] == 'system' %}"\
+        "{% if messages[1]['role'] == 'user' %}"\
+            "{{ '[INST] ' + messages[0]['content'] + ' ' + messages[1]['content'] + ' [/INST]' }}"\
+            "{% set loop_messages = messages[2:] %}"\
+        "{% else %}"\
+            "{{ '[INST] ' + messages[0]['content'] + ' [/INST]' }}"\
+            "{% set loop_messages = messages[1:] %}"\
+        "{% endif %}"\
+    "{% else %}"\
+        "{% set loop_messages = messages %}"\
+    "{% endif %}"\
+    "{% for message in loop_messages %}"\
+        "{% if message['role'] == 'user' %}"\
+            "{{ '[INST] ' + message['content'] + ' [/INST]' }}"\
+        "{% elif message['role'] == 'assistant' %}"\
+            "{{ message['content'] }}"\
+        "{% else %}"\
+            "{{ raise_exception('Only user and assistant roles are supported!') }}"\
+        "{% endif %}"\
+    "{% endfor %}"
+pass

+# Check Llama chat template without BOS / EOS
+llama_template = \
+    "{% if messages[0]['role'] == 'system' %}"\
+        "{% if messages[1]['role'] == 'user' %}"\
+            "{{ '[INST] <<SYS>>\n' + messages[0]['content'] + '\n<</SYS>>\n\n' + messages[1]['content'] + ' [/INST]' }}"\
+            "{% set loop_messages = messages[2:] %}"\
+        "{% else %}"\
+            "{{ '[INST] ' + messages[0]['content'] + ' [/INST]' }}"\
+            "{% set loop_messages = messages[1:] %}"\
+        "{% endif %}"\
+    "{% else %}"\
+        "{% set loop_messages = messages %}"\
+    "{% endif %}"\
+    "{% for message in loop_messages %}"\
+        "{% if message['role'] == 'user' %}"\
+            "{{ '[INST] ' + message['content'].strip() + ' [/INST]' }}"\
+        "{% elif message['role'] == 'assistant' %}"\
+            "{{ ' ' + message['content'].strip() + ' ' }}"\
+        "{% else %}"\
+            "{{ raise_exception('Only user and assistant roles are supported!') }}"\
+        "{% endif %}"\
+    "{% endfor %}"
+pass
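As a quick illustration (not part of the diff): once either string is assigned to a tokenizer's chat_template, transformers renders a conversation directly. A hedged sketch, assuming `mistral_template` is in scope and using an illustrative checkpoint name:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")  # assumed checkpoint
tok.chat_template = mistral_template
messages = [
    {"role": "user",      "content": "What is 2+2?"},
    {"role": "assistant", "content": "It's 4."},
]
print(tok.apply_chat_template(messages, tokenize = False))
# -> [INST] What is 2+2? [/INST]It's 4.

Because neither template emits BOS or EOS, any such tokens in the tokenized output must come from the tokenizer itself, which is exactly what the checks below rely on.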


+def select_correct_slow_tokenizer(
+    tokenizer_name,
+    model_max_length  = None,
+    padding_side      = "right",
+    token             = None,
+    trust_remote_code = False,
+    cache_dir         = "huggingface_tokenizers_cache",
+):
+    """
+    Returns 'correct' tokenizer by checking if the chat templates are
+    actually tokenized correctly.
+    """
+    messages = [
+        {"role": "user",      "content": "What is 2+2?"},
+        {"role": "assistant", "content": "It's 4."},
+    ]
+
+    settings = (
+        (False, False, True,),
+        (False, True,  True,),
+        (True,  False, True,),
+        (True,  False, False,),
+    )
+
+    for (use_fast, legacy, from_slow,) in settings:
+        # Default as mentioned by Arthur from HF:
+        slow_tokenizer = AutoTokenizer.from_pretrained(
+            tokenizer_name,
+            model_max_length  = model_max_length,
+            padding_side      = padding_side,
+            token             = token,
+            trust_remote_code = trust_remote_code,
+            # Cannot just use use_fast = False as per https://twitter.com/danielhanchen/status/1789659394302718373
+            use_fast  = use_fast,
+            legacy    = legacy,
+            from_slow = from_slow,
+            cache_dir = cache_dir,
+        )
+        slow_tokenizer_chat_template = slow_tokenizer.chat_template
+
+        slow_tokenizer.chat_template = llama_template
+        result1 = slow_tokenizer.decode(slow_tokenizer.apply_chat_template(messages))
+        slow_tokenizer.chat_template = mistral_template
+        result2 = slow_tokenizer.decode(slow_tokenizer.apply_chat_template(messages))
+
+        # If 2 spaces seen, normally wrong!
+        if " "*2 not in result1 and " "*2 not in result2:
+            slow_tokenizer.chat_template = slow_tokenizer_chat_template
+            return slow_tokenizer
+        pass
+    pass
+    # Return fast version as default
+    return slow_tokenizer
+pass
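The " "*2 probe works because mis-configured slow tokenizers tend to inject an extra space around special tokens when decoding, so a doubled space in either rendering flags a bad (use_fast, legacy, from_slow) combination. A hedged usage sketch (checkpoint name illustrative, not from the commit):

tokenizer = select_correct_slow_tokenizer(
    "meta-llama/Llama-2-7b-chat-hf",   # assumed checkpoint
    model_max_length = 4096,
    padding_side     = "right",
)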


def assert_same_tokenization(slow_tokenizer, fast_tokenizer):
    # Get eos_token, bos_token etc
    dir_names = dir(slow_tokenizer)
@@ -195,21 +300,44 @@ def assert_same_tokenization(slow_tokenizer, fast_tokenizer):
    all_special_tokens = list(set(special_tokens + slow_tokenizer.all_special_tokens))

    # Check if chat template is enabled!
-    check_chat_template = True
+    check_chat_template1 = True
+    check_chat_template2 = True
+    check_chat_template3 = True
+    slow_chat_template = getattr(slow_tokenizer, "chat_template", None)
+    fast_chat_template = getattr(fast_tokenizer, "chat_template", None)
+    messages = [
+        {"role": "user",      "content": " What is 2+2? "},
+        {"role": "assistant", "content": " It's 4. "},
+    ]
+    # Check the tokenizer's own chat template
+    if slow_chat_template is not None and fast_chat_template is not None:
+        check_chat_template1 = \
+            slow_tokenizer.apply_chat_template(messages) == \
+            fast_tokenizer.apply_chat_template(messages)
+    pass

-    if getattr(slow_tokenizer, "chat_template", None) is not None and \
-        getattr(fast_tokenizer, "chat_template", None) is not None:
-
-        # Check chat template!
-        messages = [
-            {"role": "user",      "content": " What is 2+2? "},
-            {"role": "assistant", "content": " It's 4. "},
-        ]
-        check_chat_template = \
-            slow_tokenizer(slow_tokenizer.apply_chat_template(messages)).input_ids == \
-            fast_tokenizer(slow_tokenizer.apply_chat_template(messages)).input_ids
+    # Check Mistral chat template without BOS / EOS
+    slow_tokenizer.chat_template = mistral_template
+    fast_tokenizer.chat_template = mistral_template
+    check_chat_template2 = \
+        slow_tokenizer.apply_chat_template(messages) == \
+        fast_tokenizer.apply_chat_template(messages)
+    pass

+    # Check Llama chat template without BOS / EOS
+    slow_tokenizer.chat_template = llama_template
+    fast_tokenizer.chat_template = llama_template
+    check_chat_template3 = \
+        slow_tokenizer.apply_chat_template(messages) == \
+        fast_tokenizer.apply_chat_template(messages)
+    pass

+    # Combine them all and revert chat templates
+    check_chat_template = check_chat_template1 and check_chat_template2 and check_chat_template3
+    slow_tokenizer.chat_template = slow_chat_template
+    fast_tokenizer.chat_template = fast_chat_template

    # Try special tokens
    try:
        string = "\n".join(all_special_tokens) + \
            "A quick brown fox jumps over the lazy dog!!\n\nHi</s>\n\n" + \
@@ -227,6 +355,7 @@ def assert_same_tokenization(slow_tokenizer, fast_tokenizer):
        return check_chat_template
    else:
        return False
+    pass
pass
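The three checks above share one shape: pin both tokenizers to the same template, render the same probe messages, and compare token ids. A minimal sketch of that shared step (the helper name is ours, not the commit's):

def templates_agree(slow_tokenizer, fast_tokenizer, template, messages):
    # Force both tokenizers onto the identical template, then compare ids.
    slow_tokenizer.chat_template = template
    fast_tokenizer.chat_template = template
    return slow_tokenizer.apply_chat_template(messages) == \
           fast_tokenizer.apply_chat_template(messages)

The probe messages keep leading and trailing spaces (" What is 2+2? ") deliberately: whitespace around special tokens is where slow (sentencepiece) and fast (tokenizers) implementations most often disagree.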


@@ -379,17 +508,13 @@ def load_correct_tokenizer(
    # Mainly to solve Deepseek models with no tokenizer.model file
    slow_tokenizer = None
    try:
-        slow_tokenizer = AutoTokenizer.from_pretrained(
+        slow_tokenizer = select_correct_slow_tokenizer(
            tokenizer_name,
            model_max_length  = model_max_length,
            padding_side      = padding_side,
            token             = token,
            trust_remote_code = trust_remote_code,
-            # Cannot just use use_fast = False as per https://twitter.com/danielhanchen/status/1789659394302718373
-            use_fast  = False,
-            legacy    = False,
-            from_slow = True,
            cache_dir = cache_dir,
        )
    except:
        pass
@@ -418,6 +543,7 @@ def load_correct_tokenizer(
        if assert_same_tokenization(slow_tokenizer, fast_tokenizer):
            return fast_tokenizer
        else:
+            logger.warning(f"Unsloth: Will load {tokenizer_name} as a legacy tokenizer.")
            return convert_to_fast_tokenizer(slow_tokenizer)
        pass
    else:
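Taken together, a hedged summary sketch of the selection flow after this commit (checkpoint name illustrative):

tokenizer = load_correct_tokenizer("unsloth/llama-3-8b")  # assumed checkpoint
# 1) select_correct_slow_tokenizer probes (use_fast, legacy, from_slow)
#    settings until one renders the chat templates without doubled spaces;
# 2) the fast tokenizer is kept only if assert_same_tokenization passes
#    on the tokenizer's own, Mistral, and Llama templates;
# 3) otherwise a warning is logged and the slow tokenizer is converted
#    to a fast one via convert_to_fast_tokenizer.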
