diff --git a/README.md b/README.md
index 1204d60e..72d74299 100644
--- a/README.md
+++ b/README.md
@@ -169,7 +169,7 @@ clip_inference turn a set of text+image into clip embeddings
 * **write_batch_size** Write batch size (default *10**6*)
 * **wds_image_key** Key to use for images in webdataset. (default *jpg*)
 * **wds_caption_key** Key to use for captions in webdataset. (default *txt*)
-* **clip_model** CLIP model to load (default *ViT-B/32*). Specify it as `"open_clip:ViT-B-32-quickgelu"` to use the [open_clip](https://github.com/mlfoundations/open_clip) or `"hf_clip:patrickjohncyh/fashion-clip"` to use the [hugging face](https://huggingface.co/docs/transformers/model_doc/clip) clip model.
+* **clip_model** CLIP model to load (default *ViT-B/32*). Specify it as `"open_clip:ViT-B-32/laion2b_s34b_b79k"` to use the [open_clip](https://github.com/mlfoundations/open_clip) or `"hf_clip:patrickjohncyh/fashion-clip"` to use the [hugging face](https://huggingface.co/docs/transformers/model_doc/clip) clip model.
 * **mclip_model** MCLIP model to load (default *sentence-transformers/clip-ViT-B-32-multilingual-v1*)
 * **use_mclip** If False it performs the inference using CLIP; MCLIP otherwise (default *False*)
 * **use_jit** uses jit for the clip model (default *True*)
@@ -183,7 +183,7 @@ clip_inference turn a set of text+image into clip embeddings
 * **slurm_partition** (default *None*), the slurm partition to create a job in.
 * **slurm_jobs**, the number of jobs to create in slurm. (default *None*)
 * **slurm_job_comment**, the job comment to use. (default *None*)
-* **slurm_nodelist**, a list of specific nodes to use .(default *None*
+* **slurm_nodelist**, a list of specific nodes to use .(default *None*)
 * **slurm_exclude**, a list of nodes to exclude when creating jobs. (default *None*)
 * **slurm_job_timeout**, if not supplied it will default to 2 weeks. (default *None*)
 * **slurm_cache_path**, cache path to use for slurm-related tasks. (default *None*)
diff --git a/clip_retrieval/clip_inference/mapper.py b/clip_retrieval/clip_inference/mapper.py
index 03df075c..c555916b 100644
--- a/clip_retrieval/clip_inference/mapper.py
+++ b/clip_retrieval/clip_inference/mapper.py
@@ -34,7 +34,10 @@ def __init__(
         self.use_mclip = use_mclip
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         model, _ = load_clip(
-            clip_model=clip_model, use_jit=use_jit, warmup_batch_size=warmup_batch_size, clip_cache_path=clip_cache_path
+            clip_model=clip_model,
+            use_jit=use_jit,
+            warmup_batch_size=warmup_batch_size,
+            clip_cache_path=clip_cache_path,
         )
         self.model_img = model.encode_image
         self.model_txt = model.encode_text
diff --git a/clip_retrieval/clip_inference/worker.py b/clip_retrieval/clip_inference/worker.py
index 1ccce4b4..f608b6b4 100644
--- a/clip_retrieval/clip_inference/worker.py
+++ b/clip_retrieval/clip_inference/worker.py
@@ -50,7 +50,10 @@ def worker(
 
     def reader_builder(sampler):
         _, preprocess = load_clip(
-            clip_model=clip_model, use_jit=use_jit, warmup_batch_size=batch_size, clip_cache_path=clip_cache_path
+            clip_model=clip_model,
+            use_jit=use_jit,
+            warmup_batch_size=batch_size,
+            clip_cache_path=clip_cache_path,
         )
         if input_format == "files":
             return FilesReader(
diff --git a/clip_retrieval/load_clip.py b/clip_retrieval/load_clip.py
index 1230cc61..a60443eb 100644
--- a/clip_retrieval/load_clip.py
+++ b/clip_retrieval/load_clip.py
@@ -85,11 +85,18 @@ def load_open_clip(clip_model, use_jit=True, device="cuda", clip_cache_path=None
     import open_clip  # pylint: disable=import-outside-toplevel
 
     torch.backends.cuda.matmul.allow_tf32 = True
-
-    pretrained = dict(open_clip.list_pretrained())
-    checkpoint = pretrained[clip_model]
+    clip_model_parts = clip_model.split("/")
+    clip_model = clip_model_parts[0]
+    checkpoint = "/".join(clip_model_parts[1:])
+    if checkpoint == "":
+        pretrained = dict(open_clip.list_pretrained())
+        checkpoint = pretrained[clip_model]
     model, _, preprocess = open_clip.create_model_and_transforms(
-        clip_model, pretrained=checkpoint, device=device, jit=use_jit, cache_dir=clip_cache_path
+        clip_model,
+        pretrained=checkpoint,
+        device=device,
+        jit=use_jit,
+        cache_dir=clip_cache_path,
     )
     model = OpenClipWrapper(inner_model=model, device=device)
     model.to(device=device)
@@ -201,7 +208,13 @@ def load_clip_without_warmup(clip_model, use_jit, device, clip_cache_path):
 
 
 @lru_cache(maxsize=None)
-def load_clip(clip_model="ViT-B/32", use_jit=True, warmup_batch_size=1, clip_cache_path=None, device=None):
+def load_clip(
+    clip_model="ViT-B/32",
+    use_jit=True,
+    warmup_batch_size=1,
+    clip_cache_path=None,
+    device=None,
+):
     """Load clip then warmup"""
     if device is None:
         device = "cuda" if torch.cuda.is_available() else "cpu"
diff --git a/tests/test_clip_inference/test_mapper.py b/tests/test_clip_inference/test_mapper.py
index 3a0339a2..9d48163f 100644
--- a/tests/test_clip_inference/test_mapper.py
+++ b/tests/test_clip_inference/test_mapper.py
@@ -10,7 +10,7 @@
     "model",
     [
         "ViT-B/32",
-        "open_clip:ViT-B-32-quickgelu",
+        "open_clip:ViT-B-32/laion2b_s34b_b79k",
         "hf_clip:patrickjohncyh/fashion-clip",
         "nm:mgoin/CLIP-ViT-B-32-laion2b_s34b_b79k-ds",
     ],
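For reviewers, here is a minimal standalone Python sketch of the model-string parsing that the load_open_clip change above introduces: everything before the first "/" is treated as the open_clip architecture and the remainder as the pretrained checkpoint tag, with a fallback to open_clip.list_pretrained() when no tag is given. The helper name parse_open_clip_model is illustrative only and not part of the patch.

import open_clip


def parse_open_clip_model(clip_model):
    """Split e.g. "ViT-B-32/laion2b_s34b_b79k" into (architecture, checkpoint).

    Illustrative helper mirroring the logic added to load_open_clip: when no
    checkpoint tag is supplied, fall back to the default pretrained tag that
    open_clip.list_pretrained() associates with the architecture.
    """
    clip_model_parts = clip_model.split("/")
    architecture = clip_model_parts[0]
    checkpoint = "/".join(clip_model_parts[1:])
    if checkpoint == "":
        # same fallback the previous code used for architecture-only names
        pretrained = dict(open_clip.list_pretrained())
        checkpoint = pretrained[architecture]
    return architecture, checkpoint


# prints ("ViT-B-32", "laion2b_s34b_b79k"); this is what clip_inference sees when
# the user passes clip_model="open_clip:ViT-B-32/laion2b_s34b_b79k"
print(parse_open_clip_model("ViT-B-32/laion2b_s34b_b79k"))

With this change, architecture-only names such as "open_clip:ViT-B-32-quickgelu" keep working through the list_pretrained() fallback, while "open_clip:ViT-B-32/laion2b_s34b_b79k" lets users pin a specific checkpoint, which is the format the updated README and test now use.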