diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py
index cf62d5a1c..08076de13 100644
--- a/src/open_clip/factory.py
+++ b/src/open_clip/factory.py
@@ -74,7 +74,15 @@ def get_model_config(model_name):
 
 
 def _get_hf_config(model_id, cache_dir=None):
-    config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)
+    config_path = None
+    if cache_dir is not None:
+        # Offline-friendly path: use a config pre-placed at <cache_dir>/<model_id>/ when present.
+        local_config = os.path.join(cache_dir, model_id, 'open_clip_config.json')
+        if os.path.exists(local_config):
+            config_path = local_config
+    if config_path is None:
+        # Fall back to the HF Hub download; cache_dir is still honoured as the HF cache location.
+        config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)
     with open(config_path, 'r', encoding='utf-8') as f:
         config = json.load(f)
     return config
@@ -83,12 +91,13 @@ def _get_hf_config(model_id, cache_dir=None):
 def get_tokenizer(
         model_name: str = '',
         context_length: Optional[int] = None,
+        cache_dir: Optional[str] = None,
         **kwargs,
 ):
     if model_name.startswith(HF_HUB_PREFIX):
         model_name = model_name[len(HF_HUB_PREFIX):]
         try:
-            config = _get_hf_config(model_name)['model_cfg']
+            config = _get_hf_config(model_name, cache_dir=cache_dir)['model_cfg']
         except Exception:
             tokenizer = HFTokenizer(
                 model_name,
@@ -185,7 +194,18 @@ def create_model(
     has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
     if has_hf_hub_prefix:
         model_id = model_name[len(HF_HUB_PREFIX):]
-        checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
+        checkpoint_path = None
+        if cache_dir is not None:
+            # Offline-friendly path: prefer a checkpoint pre-placed at <cache_dir>/<model_id>/.
+            # Safetensors first — presumably matches the hub-side naming used by download_pretrained_from_hf; confirm.
+            for ckpt_name in ('open_clip_model.safetensors', 'open_clip_pytorch_model.bin'):
+                local_ckpt = os.path.join(cache_dir, model_id, ckpt_name)
+                if os.path.exists(local_ckpt):
+                    checkpoint_path = local_ckpt
+                    break
+        if checkpoint_path is None:
+            # Fall back to the HF Hub download; cache_dir is still honoured as the HF cache location.
+            checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
         config = _get_hf_config(model_id, cache_dir)
         preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])
         model_cfg = config['model_cfg']