Add support for the full open clip model name format: ViT-B-32/laion2b_s34b_b79k #314

Merged · 8 commits · Jan 6, 2024
4 changes: 2 additions & 2 deletions README.md
@@ -169,7 +169,7 @@ clip_inference turn a set of text+image into clip embeddings
* **write_batch_size** Write batch size (default *10**6*)
* **wds_image_key** Key to use for images in webdataset. (default *jpg*)
* **wds_caption_key** Key to use for captions in webdataset. (default *txt*)
-* **clip_model** CLIP model to load (default *ViT-B/32*). Specify it as `"open_clip:ViT-B-32-quickgelu"` to use the [open_clip](https://github.com/mlfoundations/open_clip) or `"hf_clip:patrickjohncyh/fashion-clip"` to use the [hugging face](https://huggingface.co/docs/transformers/model_doc/clip) clip model.
+* **clip_model** CLIP model to load (default *ViT-B/32*). Specify it as `"open_clip:ViT-B-32/laion2b_s34b_b79k"` to use the [open_clip](https://github.com/mlfoundations/open_clip) or `"hf_clip:patrickjohncyh/fashion-clip"` to use the [hugging face](https://huggingface.co/docs/transformers/model_doc/clip) clip model.
* **mclip_model** MCLIP model to load (default *sentence-transformers/clip-ViT-B-32-multilingual-v1*)
* **use_mclip** If False it performs the inference using CLIP; MCLIP otherwise (default *False*)
* **use_jit** uses jit for the clip model (default *True*)
@@ -183,7 +183,7 @@ clip_inference turn a set of text+image into clip embeddings
* **slurm_partition** (default *None*), the slurm partition to create a job in.
* **slurm_jobs**, the number of jobs to create in slurm. (default *None*)
* **slurm_job_comment**, the job comment to use. (default *None*)
-* **slurm_nodelist**, a list of specific nodes to use .(default *None*
+* **slurm_nodelist**, a list of specific nodes to use .(default *None*)
* **slurm_exclude**, a list of nodes to exclude when creating jobs. (default *None*)
* **slurm_job_timeout**, if not supplied it will default to 2 weeks. (default *None*)
* **slurm_cache_path**, cache path to use for slurm-related tasks. (default *None*)
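For reference, a minimal sketch (not part of the diff) of the two `open_clip:` naming formats the updated `clip_model` option accepts. It assumes the `load_clip` helper from `clip_retrieval/load_clip.py` shown further down and that the named checkpoints can be downloaded.

```python
# Hedged sketch, not from this PR: exercising the clip_model formats
# documented in the README bullet above via clip_retrieval.load_clip.
from clip_retrieval.load_clip import load_clip

# Architecture-only form: the pretrained checkpoint is resolved through
# open_clip.list_pretrained(), as before this change.
model, preprocess = load_clip(clip_model="open_clip:ViT-B-32-quickgelu", use_jit=False)

# Full "architecture/checkpoint" form introduced by this PR.
model, preprocess = load_clip(clip_model="open_clip:ViT-B-32/laion2b_s34b_b79k", use_jit=False)
```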
5 changes: 4 additions & 1 deletion clip_retrieval/clip_inference/mapper.py
@@ -34,7 +34,10 @@ def __init__(
        self.use_mclip = use_mclip
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        model, _ = load_clip(
-            clip_model=clip_model, use_jit=use_jit, warmup_batch_size=warmup_batch_size, clip_cache_path=clip_cache_path
+            clip_model=clip_model,
+            use_jit=use_jit,
+            warmup_batch_size=warmup_batch_size,
+            clip_cache_path=clip_cache_path,
        )
        self.model_img = model.encode_image
        self.model_txt = model.encode_text
5 changes: 4 additions & 1 deletion clip_retrieval/clip_inference/worker.py
@@ -50,7 +50,10 @@ def worker(

    def reader_builder(sampler):
        _, preprocess = load_clip(
-            clip_model=clip_model, use_jit=use_jit, warmup_batch_size=batch_size, clip_cache_path=clip_cache_path
+            clip_model=clip_model,
+            use_jit=use_jit,
+            warmup_batch_size=batch_size,
+            clip_cache_path=clip_cache_path,
        )
        if input_format == "files":
            return FilesReader(
23 changes: 18 additions & 5 deletions clip_retrieval/load_clip.py
@@ -85,11 +85,18 @@ def load_open_clip(clip_model, use_jit=True, device="cuda", clip_cache_path=None):
    import open_clip  # pylint: disable=import-outside-toplevel

    torch.backends.cuda.matmul.allow_tf32 = True

-    pretrained = dict(open_clip.list_pretrained())
-    checkpoint = pretrained[clip_model]
+    clip_model_parts = clip_model.split("/")
+    clip_model = clip_model_parts[0]
+    checkpoint = "/".join(clip_model_parts[1:])
+    if checkpoint == "":
+        pretrained = dict(open_clip.list_pretrained())
+        checkpoint = pretrained[clip_model]
    model, _, preprocess = open_clip.create_model_and_transforms(
-        clip_model, pretrained=checkpoint, device=device, jit=use_jit, cache_dir=clip_cache_path
+        clip_model,
+        pretrained=checkpoint,
+        device=device,
+        jit=use_jit,
+        cache_dir=clip_cache_path,
    )
    model = OpenClipWrapper(inner_model=model, device=device)
    model.to(device=device)
@@ -201,7 +208,13 @@ def load_clip_without_warmup(clip_model, use_jit, device, clip_cache_path):


@lru_cache(maxsize=None)
-def load_clip(clip_model="ViT-B/32", use_jit=True, warmup_batch_size=1, clip_cache_path=None, device=None):
+def load_clip(
+    clip_model="ViT-B/32",
+    use_jit=True,
+    warmup_batch_size=1,
+    clip_cache_path=None,
+    device=None,
+):
    """Load clip then warmup"""
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
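To make the new behaviour in `load_open_clip` explicit: the part of the model string before the first `/` is the open_clip architecture, everything after it is the pretrained checkpoint tag, and an empty tag falls back to the existing `open_clip.list_pretrained()` lookup. The following standalone helper is a hypothetical illustration of that splitting logic, not code from the PR.

```python
# Hypothetical helper mirroring the splitting logic added to load_open_clip
# above; illustrative only, not part of the PR.
def split_open_clip_name(clip_model: str):
    parts = clip_model.split("/")
    arch = parts[0]
    checkpoint = "/".join(parts[1:])
    # Empty checkpoint -> caller falls back to dict(open_clip.list_pretrained())[arch]
    return arch, checkpoint

assert split_open_clip_name("ViT-B-32/laion2b_s34b_b79k") == ("ViT-B-32", "laion2b_s34b_b79k")
assert split_open_clip_name("ViT-B-32-quickgelu") == ("ViT-B-32-quickgelu", "")
```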
2 changes: 1 addition & 1 deletion tests/test_clip_inference/test_mapper.py
@@ -10,7 +10,7 @@
    "model",
    [
        "ViT-B/32",
-        "open_clip:ViT-B-32-quickgelu",
+        "open_clip:ViT-B-32/laion2b_s34b_b79k",
        "hf_clip:patrickjohncyh/fashion-clip",
        "nm:mgoin/CLIP-ViT-B-32-laion2b_s34b_b79k-ds",
    ],