Skip to content

Commit

Permalink
Fix the marian tokenizer importer. (#2426)
Browse files Browse the repository at this point in the history
* Fix the marian tokenizer importer.

* Ignore the python caches.
  • Loading branch information
LaurentMazare authored Aug 17, 2024
1 parent c1b9e07 commit b75ef05
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,4 @@ candle-wasm-examples/*/package-lock.json
candle-wasm-examples/**/config*.json
.DS_Store
.idea/*
__pycache__
20 changes: 16 additions & 4 deletions candle-examples/examples/marian-mt/convert_slow_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ def import_protobuf(error_message=""):
else:
raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message))

def _get_prepend_scheme(add_prefix_space: bool, original_tokenizer) -> str:
if add_prefix_space:
prepend_scheme = "always"
if hasattr(original_tokenizer, "legacy") and not original_tokenizer.legacy:
prepend_scheme = "first"
else:
prepend_scheme = "never"
return prepend_scheme

class SentencePieceExtractor:
"""
Expand Down Expand Up @@ -519,13 +527,15 @@ def normalizer(self, proto):
)

def pre_tokenizer(self, replacement, add_prefix_space):
return pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)

def post_processor(self):
return None

def decoder(self, replacement, add_prefix_space):
return decoders.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
return decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)

def converted(self) -> Tokenizer:
tokenizer = self.tokenizer(self.proto)
Expand Down Expand Up @@ -636,7 +646,8 @@ def pre_tokenizer(self, replacement, add_prefix_space):
list_pretokenizers = []
if self.original_tokenizer.split_by_punct:
list_pretokenizers.append(pre_tokenizers.Punctuation(behavior="isolated"))
list_pretokenizers.append(pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space))
prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
list_pretokenizers.append(pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme))
return pre_tokenizers.Sequence(list_pretokenizers)

def normalizer(self, proto):
Expand Down Expand Up @@ -929,10 +940,11 @@ def unk_id(self, proto):
return proto.trainer_spec.unk_id + self.original_tokenizer.offset

def pre_tokenizer(self, replacement, add_prefix_space):
prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
return pre_tokenizers.Sequence(
[
pre_tokenizers.WhitespaceSplit(),
pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space),
pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme),
]
)

Expand Down

0 comments on commit b75ef05

Please sign in to comment.