diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index 12f3dec30..7268522e4 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -206,7 +206,7 @@ def create_model( if custom_text: if is_hf_model: model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf - if "coca" in model_name: + if "multimodal_cfg" in model_cfg: model = CoCa(**model_cfg, **model_kwargs, cast_dtype=cast_dtype) else: model = CustomTextCLIP(**model_cfg, **model_kwargs, cast_dtype=cast_dtype)