diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index 4637a82b9..48b45c6aa 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -78,6 +78,9 @@ def get_model_config(model_name): def get_tokenizer(model_name): if model_name.startswith(HF_HUB_PREFIX): tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):]) + if 'context_length' in get_model_config(model_name[len(HF_HUB_PREFIX):])['text_cfg'].keys(): + context_length = get_model_config(model_name[len(HF_HUB_PREFIX):])['text_cfg']['context_length'] + tokenizer = partial(tokenizer, context_length=context_length) else: config = get_model_config(model_name) if 'hf_tokenizer_name' in config['text_cfg']: @@ -90,8 +93,8 @@ def get_tokenizer(model_name): tokenizer = block_mask_tokenize else: tokenizer = tokenize - context_length = get_model_config(model_name)['text_cfg']['context_length'] - tokenizer = partial(tokenizer, context_length=context_length) + context_length = get_model_config(model_name)['text_cfg']['context_length'] + tokenizer = partial(tokenizer, context_length=context_length) return tokenizer