diff --git a/src/formatron/integrations/RWKV.py b/src/formatron/integrations/RWKV.py
index 06cae120..b304d773 100644
--- a/src/formatron/integrations/RWKV.py
+++ b/src/formatron/integrations/RWKV.py
@@ -8,7 +8,7 @@
 from formatron.config import EngineGenerationConfig
 from formatron.formatter import FormatterBuilder
 
-
+__all__ = ["create_engine_vocabulary", "PIPELINE", "PIPELINE_ARGS"]
 class PIPELINE_ARGS(rwkv.utils.PIPELINE_ARGS):
     """
     A wrapper for the arguments of the pipeline of RWKV.
- """ - bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) - cs = bs[:] - n = 0 - for b in range(2**8): - if b not in bs: - bs.append(b) - cs.append(2**8+n) - n += 1 - cs = [chr(n).encode("UTF-8") for n in cs] - for i in range(len(bs)): - bs[i] = bytes([bs[i]]) - return dict(zip(cs, bs)) \ No newline at end of file diff --git a/src/formatron/integrations/exllamav2.py b/src/formatron/integrations/exllamav2.py index f1954a31..14b17018 100644 --- a/src/formatron/integrations/exllamav2.py +++ b/src/formatron/integrations/exllamav2.py @@ -9,31 +9,44 @@ from exllamav2.generator.base import ExLlamaV2Filter from formatron.config import EngineGenerationConfig from formatron.formatter import FormatterBase, FormatterBuilder -from formatron.integrations._utils import get_original_characters -from functools import lru_cache +from formatron.integrations.utils import get_original_characters -def create_engine_vocabulary(tokenizer: ExLlamaV2Tokenizer) -> kbnf.Vocabulary: +__all__ = ["create_engine_vocabulary", "create_formatter_filter", "FormatterFilter"] +def create_engine_vocabulary(tokenizer: ExLlamaV2Tokenizer, + vocab_processors: typing.Optional[list[typing.Callable]] = None) -> kbnf.Vocabulary: """ Create a vocabulary for the KBNF engine. + Args: + tokenizer: The tokenizer. + vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None. + Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected. """ assert hasattr(tokenizer.tokenizer_model, "vocab"), (f"tokenizer({tokenizer})" f" with tokenizer_model({tokenizer.tokenizer_model})" f" does not have vocab attribute!") vocab = {tokenizer.tokenizer_model.id_to_piece( i): i for i in range(tokenizer.tokenizer_model.vocab_size())} - new_vocab = get_original_characters(vocab) + new_vocab = get_original_characters(vocab, vocab_processors) return kbnf.Vocabulary({k: kbnf.Token(v) for k, v in new_vocab.items()}, {v: k for k, v in vocab.items()}) def create_formatter_filter(model: ExLlamaV2, tokenizer: ExLlamaV2Tokenizer, formatter_builder: FormatterBuilder, - engine_config: EngineGenerationConfig = None) -> ExLlamaV2Filter: + engine_config: EngineGenerationConfig = None, + vocab_processors: typing.Optional[list[typing.Callable]] = None) -> ExLlamaV2Filter: """ Create a formatter filter for the ExLlamaV2 engine. + Args: + model: The ExLlamaV2 model. + tokenizer: The ExLlamaV2 tokenizer. + formatter_builder: The formatter builder. + engine_config: The engine generation configuration. + vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None. + Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected. 
""" - vocab = create_engine_vocabulary(tokenizer) + vocab = create_engine_vocabulary(tokenizer, vocab_processors) f = formatter_builder.build( vocab, lambda tokens: tokenizer.decode(torch.tensor(tokens))) return FormatterFilter(model, tokenizer, f, engine_config) diff --git a/src/formatron/integrations/transformers.py b/src/formatron/integrations/transformers.py index 223e2b27..3ef5950d 100644 --- a/src/formatron/integrations/transformers.py +++ b/src/formatron/integrations/transformers.py @@ -9,26 +9,39 @@ from formatron.config import EngineGenerationConfig from formatron.formatter import FormatterBuilder, FormatterBase -from formatron.integrations._utils import get_original_characters +from formatron.integrations.utils import get_original_characters +__all__ = ["create_engine_vocabulary", "create_formatter_logits_processor", "create_formatter_logits_processor_list", "FormattersLogitsProcessor"] -def create_engine_vocabulary(tokenizer: PreTrainedTokenizerBase) -> kbnf.Vocabulary: +def create_engine_vocabulary(tokenizer: PreTrainedTokenizerBase, + vocab_processors: typing.Optional[list[typing.Callable]] = None) -> kbnf.Vocabulary: """ Create a vocabulary for the KBNF engine. + Args: + tokenizer: The tokenizer. + vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None. + Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected. """ vocab = tokenizer.get_vocab() - new_vocab = get_original_characters(vocab) + new_vocab = get_original_characters(vocab, vocab_processors) return kbnf.Vocabulary({k: kbnf.Token(v) for k, v in new_vocab.items()}, {v: k for k, v in vocab.items()}) def create_formatter_logits_processor(tokenizer: PreTrainedTokenizerBase, formatter_builders: typing.Sequence[FormatterBuilder | None] | FormatterBuilder, - configs: typing.Sequence[EngineGenerationConfig] = None) -> LogitsProcessor: + configs: typing.Sequence[EngineGenerationConfig] = None, + vocab_processors: typing.Optional[list[typing.Callable]] = None) -> LogitsProcessor: """ Create a formatter logits processor. + Args: + tokenizer: The tokenizer. + formatter_builders: The formatter builders. + configs: The engine generation configurations. + vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None. + Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected. """ - vocab = create_engine_vocabulary(tokenizer) + vocab = create_engine_vocabulary(tokenizer, vocab_processors) if not isinstance(formatter_builders, collections.abc.Sequence): formatter_builders = [formatter_builders] formatters = [i.build(vocab, lambda tokens: tokenizer.decode(tokens)) if i is not None else None @@ -38,13 +51,20 @@ def create_formatter_logits_processor(tokenizer: PreTrainedTokenizerBase, def create_formatter_logits_processor_list(tokenizer: PreTrainedTokenizerBase, formatter_builders: typing.Sequence[FormatterBuilder | None] | FormatterBuilder, - configs: typing.Sequence[EngineGenerationConfig] = None) \ + configs: typing.Sequence[EngineGenerationConfig] = None, + vocab_processors: typing.Optional[list[typing.Callable]] = None) \ -> LogitsProcessorList: """ Create a formatter logits processor list. + Args: + tokenizer: The tokenizer. + formatter_builders: The formatter builders. + configs: The engine generation configurations. + vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None. 
diff --git a/src/formatron/integrations/transformers.py b/src/formatron/integrations/transformers.py
index 223e2b27..3ef5950d 100644
--- a/src/formatron/integrations/transformers.py
+++ b/src/formatron/integrations/transformers.py
@@ -9,26 +9,39 @@
 from formatron.config import EngineGenerationConfig
 from formatron.formatter import FormatterBuilder, FormatterBase
-from formatron.integrations._utils import get_original_characters
+from formatron.integrations.utils import get_original_characters
 
+__all__ = ["create_engine_vocabulary", "create_formatter_logits_processor", "create_formatter_logits_processor_list", "FormattersLogitsProcessor"]
 
 
-def create_engine_vocabulary(tokenizer: PreTrainedTokenizerBase) -> kbnf.Vocabulary:
+def create_engine_vocabulary(tokenizer: PreTrainedTokenizerBase,
+                             vocab_processors: typing.Optional[list[typing.Callable]] = None) -> kbnf.Vocabulary:
     """
     Create a vocabulary for the KBNF engine.
+    Args:
+        tokenizer: The tokenizer.
+        vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
+            Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
     """
     vocab = tokenizer.get_vocab()
-    new_vocab = get_original_characters(vocab)
+    new_vocab = get_original_characters(vocab, vocab_processors)
     return kbnf.Vocabulary({k: kbnf.Token(v) for k, v in new_vocab.items()},
                            {v: k for k, v in vocab.items()})
 
 
 def create_formatter_logits_processor(tokenizer: PreTrainedTokenizerBase,
                                       formatter_builders: typing.Sequence[FormatterBuilder | None] | FormatterBuilder,
-                                      configs: typing.Sequence[EngineGenerationConfig] = None) -> LogitsProcessor:
+                                      configs: typing.Sequence[EngineGenerationConfig] = None,
+                                      vocab_processors: typing.Optional[list[typing.Callable]] = None) -> LogitsProcessor:
     """
     Create a formatter logits processor.
+    Args:
+        tokenizer: The tokenizer.
+        formatter_builders: The formatter builders.
+        configs: The engine generation configurations.
+        vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
+            Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
     """
-    vocab = create_engine_vocabulary(tokenizer)
+    vocab = create_engine_vocabulary(tokenizer, vocab_processors)
     if not isinstance(formatter_builders, collections.abc.Sequence):
         formatter_builders = [formatter_builders]
     formatters = [i.build(vocab, lambda tokens: tokenizer.decode(tokens)) if i is not None else None
@@ -38,13 +51,20 @@
 def create_formatter_logits_processor_list(tokenizer: PreTrainedTokenizerBase,
                                            formatter_builders: typing.Sequence[FormatterBuilder | None] | FormatterBuilder,
-                                           configs: typing.Sequence[EngineGenerationConfig] = None) \
+                                           configs: typing.Sequence[EngineGenerationConfig] = None,
+                                           vocab_processors: typing.Optional[list[typing.Callable]] = None) \
         -> LogitsProcessorList:
     """
     Create a formatter logits processor list.
+    Args:
+        tokenizer: The tokenizer.
+        formatter_builders: The formatter builders.
+        configs: The engine generation configurations.
+        vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
+            Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
     """
     return LogitsProcessorList([create_formatter_logits_processor(tokenizer,
-                                                                   formatter_builders, configs)])
+                                                                   formatter_builders, configs, vocab_processors)])
 
 
 class FormattersLogitsProcessor(LogitsProcessor):
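
Likewise for the transformers integration, a minimal sketch (not part of the diff) of the new optional parameter. The model id is only an example, and omitting `vocab_processors` preserves the previous auto-detection behaviour.

from transformers import AutoModelForCausalLM, AutoTokenizer
from formatron.formatter import FormatterBuilder
from formatron.integrations.transformers import create_formatter_logits_processor_list

model_id = "microsoft/Phi-3-mini-4k-instruct"  # example model; any causal LM works the same way
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

f = FormatterBuilder()
f.append_line("Hello, Formatron!")
# vocab_processors is left out here, so the processors are auto-detected as before.
logits_processor = create_formatter_logits_processor_list(tokenizer, f)
inputs = tokenizer(["Repeat after me:"], return_tensors="pt")
out = model.generate(**inputs, logits_processor=logits_processor, max_new_tokens=32)
print(tokenizer.batch_decode(out))
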
+ """ + result = [] + llama_present = any(i.find('<0xF0>') != -1 for i in vocab.keys()) + underscore_present = (len([1 for i in vocab.keys() if i.find('\u2581') != -1]) / len(vocab)) > 0.2 + g_present = (len([1 for i in vocab.keys() if i.find('\u0120') != -1]) / len(vocab)) > 0.2 + if llama_present: + result.append(update_vocab_0xHH) + if underscore_present: + result.append(update_vocab_sentencepiece) + elif g_present: + result.append(update_vocab_dot_G) + return result + + +def update_vocab_0xHH(token_to_char: typing.Dict[bytes, bytes]): + """ + Vocabulary processor for <0xHH> tokens (used in llama tokenizers) + """ + for j in range(256): + token_to_char[("<0x" + f"{j:02x}".upper() + ">").encode("UTF-8")] = bytes([j]) + + +def update_vocab_sentencepiece(token_to_char: typing.Dict[bytes, bytes]): + """ + Vocabulary processor for ▁ token (used in sentencepiece tokenizers) + """ + token_to_char["\u2581".encode("UTF-8")] = b" " + + +def update_vocab_dot_G(token_to_char: typing.Dict[bytes, bytes]): + """ + Vocabulary processor for GPT2 style token mangling, like from \\n to Ġ(used in huggingface bytelevel preprocessors) + """ + token_to_char.update(_huggingface_bytelevel_decoder()) + + +@lru_cache() +def _huggingface_bytelevel_decoder(): + """ + I hate legacy code. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n).encode("UTF-8") for n in cs] + for i in range(len(bs)): + bs[i] = bytes([bs[i]]) + return dict(zip(cs, bs)) diff --git a/src/formatron/integrations/vllm.py b/src/formatron/integrations/vllm.py index d50f54f0..9335e8ba 100644 --- a/src/formatron/integrations/vllm.py +++ b/src/formatron/integrations/vllm.py @@ -2,14 +2,13 @@ This module integrates the vllm library by providing convenience utilities. """ import collections.abc -import time import typing import kbnf -import torch from vllm import LLM from formatron.config import EngineGenerationConfig from formatron.formatter import FormatterBase, FormatterBuilder -from formatron.integrations._utils import get_original_characters +from formatron.integrations.utils import get_original_characters +from vllm.transformers_utils.tokenizer import AnyTokenizer class FormattersLogitsProcessor: @@ -97,26 +96,37 @@ def __call__(self, prompt, generated_tokens, logits): return logits -def create_engine_vocabulary(llm: LLM) -> kbnf.Vocabulary: +def create_engine_vocabulary(tokenizer: AnyTokenizer, + vocab_processors: typing.Optional[list[typing.Callable]] = None) -> kbnf.Vocabulary: """ Create a vocabulary for the KBNF engine. + Args: + tokenizer: The tokenizer. + vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None. + Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected. 
""" - tokenizer = llm.get_tokenizer() vocab = tokenizer.get_vocab() - new_vocab = get_original_characters(vocab) + new_vocab = get_original_characters(vocab, vocab_processors) return kbnf.Vocabulary({k: kbnf.Token(v) for k, v in new_vocab.items()}, { v: k for k, v in vocab.items()}) def create_formatters_logits_processor(llm: LLM, formatter_builders: typing.Sequence[FormatterBuilder | None] | FormatterBuilder, - configs: typing.Sequence[EngineGenerationConfig] = None) \ + configs: typing.Sequence[EngineGenerationConfig] = None, + vocab_processors: typing.Optional[list[typing.Callable]] = None) \ -> FormattersLogitsProcessor: """ Create a formatter logits processor. + Args: + llm: The LLM. + formatter_builders: The formatter builders. + configs: The engine generation configurations. + vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None. + Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected. """ tokenizer = llm.get_tokenizer() - vocab = create_engine_vocabulary(llm) + vocab = create_engine_vocabulary(tokenizer, vocab_processors) if not isinstance(formatter_builders, collections.abc.Sequence): formatter_builders = [formatter_builders] formatters = [i.build(vocab, lambda tokens: tokenizer.decode(tokens)) if i is not None else None