Skip to content

Commit

Permalink
update get_tokenizer to pass CI test
Browse files Browse the repository at this point in the history
  • Loading branch information
zw615 committed Oct 3, 2023
1 parent 8f187ae commit bc20118
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/open_clip/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def get_model_config(model_name):
def get_tokenizer(model_name):
if model_name.startswith(HF_HUB_PREFIX):
tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):])
if 'context_length' in get_model_config(model_name[len(HF_HUB_PREFIX):])['text_cfg'].keys():
context_length = get_model_config(model_name[len(HF_HUB_PREFIX):])['text_cfg']['context_length']
tokenizer = partial(tokenizer, context_length=context_length)
else:
config = get_model_config(model_name)
if 'hf_tokenizer_name' in config['text_cfg']:
Expand All @@ -90,8 +93,8 @@ def get_tokenizer(model_name):
tokenizer = block_mask_tokenize
else:
tokenizer = tokenize
context_length = get_model_config(model_name)['text_cfg']['context_length']
tokenizer = partial(tokenizer, context_length=context_length)
context_length = get_model_config(model_name)['text_cfg']['context_length']
tokenizer = partial(tokenizer, context_length=context_length)
return tokenizer


Expand Down

0 comments on commit bc20118

Please sign in to comment.