From d0999dbbf9d1d34337dc97c973d3508d83b01213 Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Mon, 18 Sep 2023 12:52:15 +0200 Subject: [PATCH 1/8] support `--clip_checkpoint` option for using a specific clip checkpoint --- clip_retrieval/clip_inference/main.py | 1 + clip_retrieval/clip_inference/mapper.py | 3 ++- clip_retrieval/clip_inference/worker.py | 4 +++- clip_retrieval/load_clip.py | 28 ++++++++++++++++++------- 4 files changed, 26 insertions(+), 10 deletions(-) diff --git a/clip_retrieval/clip_inference/main.py b/clip_retrieval/clip_inference/main.py index ee74377..a8d0109 100644 --- a/clip_retrieval/clip_inference/main.py +++ b/clip_retrieval/clip_inference/main.py @@ -80,6 +80,7 @@ def main( wds_image_key="jpg", wds_caption_key="txt", clip_model="ViT-B/32", + clip_checkpoint=None, mclip_model="sentence-transformers/clip-ViT-B-32-multilingual-v1", use_mclip=False, use_jit=False, diff --git a/clip_retrieval/clip_inference/mapper.py b/clip_retrieval/clip_inference/mapper.py index 03df075..f638dd6 100644 --- a/clip_retrieval/clip_inference/mapper.py +++ b/clip_retrieval/clip_inference/mapper.py @@ -27,6 +27,7 @@ def __init__( mclip_model, warmup_batch_size=1, clip_cache_path=None, + checkpoint=None, ): self.enable_image = enable_image self.enable_text = enable_text @@ -34,7 +35,7 @@ def __init__( self.use_mclip = use_mclip self.device = "cuda" if torch.cuda.is_available() else "cpu" model, _ = load_clip( - clip_model=clip_model, use_jit=use_jit, warmup_batch_size=warmup_batch_size, clip_cache_path=clip_cache_path + clip_model=clip_model, use_jit=use_jit, warmup_batch_size=warmup_batch_size, clip_cache_path=clip_cache_path, checkpoint=checkpoint, ) self.model_img = model.encode_image self.model_txt = model.encode_text diff --git a/clip_retrieval/clip_inference/worker.py b/clip_retrieval/clip_inference/worker.py index 1ccce4b..88611f3 100644 --- a/clip_retrieval/clip_inference/worker.py +++ b/clip_retrieval/clip_inference/worker.py @@ -34,6 +34,7 @@ def worker( wds_image_key="jpg", wds_caption_key="txt", clip_model="ViT-B/32", + clip_checkpoint=None, mclip_model="sentence-transformers/clip-ViT-B-32-multilingual-v1", use_mclip=False, use_jit=True, @@ -50,7 +51,7 @@ def worker( def reader_builder(sampler): _, preprocess = load_clip( - clip_model=clip_model, use_jit=use_jit, warmup_batch_size=batch_size, clip_cache_path=clip_cache_path + clip_model=clip_model, use_jit=use_jit, warmup_batch_size=batch_size, clip_cache_path=clip_cache_path, checkpoint=clip_checkpoint, ) if input_format == "files": return FilesReader( @@ -91,6 +92,7 @@ def mapper_builder(): mclip_model=mclip_model, clip_cache_path=clip_cache_path, warmup_batch_size=batch_size, + checkpoint=clip_checkpoint, ) def writer_builder(i): diff --git a/clip_retrieval/load_clip.py b/clip_retrieval/load_clip.py index 1230cc6..ea080d1 100644 --- a/clip_retrieval/load_clip.py +++ b/clip_retrieval/load_clip.py @@ -79,15 +79,27 @@ def load_hf_clip(clip_model, device="cuda"): return model, lambda x: preprocess(x, return_tensors="pt").pixel_values -def load_open_clip(clip_model, use_jit=True, device="cuda", clip_cache_path=None): +def load_hf_clip(clip_model, device="cuda"): + """load hf clip""" + from transformers import CLIPProcessor, CLIPModel # pylint: disable=import-outside-toplevel + + model = CLIPModel.from_pretrained(clip_model) + preprocess = CLIPProcessor.from_pretrained(clip_model).image_processor + model = HFClipWrapper(inner_model=model, device=device) + model.to(device=device) + return model, lambda x: preprocess(x, 
return_tensors="pt").pixel_values + + +def load_open_clip(clip_model, use_jit=True, device="cuda", clip_cache_path=None, checkpoint=None): """load open clip""" import open_clip # pylint: disable=import-outside-toplevel torch.backends.cuda.matmul.allow_tf32 = True - - pretrained = dict(open_clip.list_pretrained()) - checkpoint = pretrained[clip_model] + + if checkpoint is None: + pretrained = dict(open_clip.list_pretrained()) + checkpoint = pretrained[clip_model] model, _, preprocess = open_clip.create_model_and_transforms( clip_model, pretrained=checkpoint, device=device, jit=use_jit, cache_dir=clip_cache_path ) @@ -184,11 +196,11 @@ def get_tokenizer(clip_model): @lru_cache(maxsize=None) -def load_clip_without_warmup(clip_model, use_jit, device, clip_cache_path): +def load_clip_without_warmup(clip_model, use_jit, device, clip_cache_path, checkpoint=None): """Load clip""" if clip_model.startswith("open_clip:"): clip_model = clip_model[len("open_clip:") :] - model, preprocess = load_open_clip(clip_model, use_jit, device, clip_cache_path) + model, preprocess = load_open_clip(clip_model, use_jit, device, clip_cache_path, checkpoint=checkpoint) elif clip_model.startswith("hf_clip:"): clip_model = clip_model[len("hf_clip:") :] model, preprocess = load_hf_clip(clip_model, device) @@ -201,11 +213,11 @@ def load_clip_without_warmup(clip_model, use_jit, device, clip_cache_path): @lru_cache(maxsize=None) -def load_clip(clip_model="ViT-B/32", use_jit=True, warmup_batch_size=1, clip_cache_path=None, device=None): +def load_clip(clip_model="ViT-B/32", use_jit=True, warmup_batch_size=1, clip_cache_path=None, device=None, checkpoint=None): """Load clip then warmup""" if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" - model, preprocess = load_clip_without_warmup(clip_model, use_jit, device, clip_cache_path) + model, preprocess = load_clip_without_warmup(clip_model, use_jit, device, clip_cache_path, checkpoint=checkpoint) start = time.time() print(f"warming up with batch size {warmup_batch_size} on {device}", flush=True) From cb65d4437c4880959ed161b704d64965f5eeb600 Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Mon, 18 Sep 2023 13:12:07 +0200 Subject: [PATCH 2/8] linting --- clip_retrieval/clip_inference/mapper.py | 3 ++- clip_retrieval/clip_inference/worker.py | 3 ++- clip_retrieval/load_clip.py | 6 ++++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/clip_retrieval/clip_inference/mapper.py b/clip_retrieval/clip_inference/mapper.py index f638dd6..7436207 100644 --- a/clip_retrieval/clip_inference/mapper.py +++ b/clip_retrieval/clip_inference/mapper.py @@ -35,7 +35,8 @@ def __init__( self.use_mclip = use_mclip self.device = "cuda" if torch.cuda.is_available() else "cpu" model, _ = load_clip( - clip_model=clip_model, use_jit=use_jit, warmup_batch_size=warmup_batch_size, clip_cache_path=clip_cache_path, checkpoint=checkpoint, + clip_model=clip_model, use_jit=use_jit, warmup_batch_size=warmup_batch_size, + clip_cache_path=clip_cache_path, checkpoint=checkpoint, ) self.model_img = model.encode_image self.model_txt = model.encode_text diff --git a/clip_retrieval/clip_inference/worker.py b/clip_retrieval/clip_inference/worker.py index 88611f3..53ba4e5 100644 --- a/clip_retrieval/clip_inference/worker.py +++ b/clip_retrieval/clip_inference/worker.py @@ -51,7 +51,8 @@ def worker( def reader_builder(sampler): _, preprocess = load_clip( - clip_model=clip_model, use_jit=use_jit, warmup_batch_size=batch_size, clip_cache_path=clip_cache_path, checkpoint=clip_checkpoint, + 
clip_model=clip_model, use_jit=use_jit, warmup_batch_size=batch_size, + clip_cache_path=clip_cache_path, checkpoint=clip_checkpoint, ) if input_format == "files": return FilesReader( diff --git a/clip_retrieval/load_clip.py b/clip_retrieval/load_clip.py index ea080d1..ab5377f 100644 --- a/clip_retrieval/load_clip.py +++ b/clip_retrieval/load_clip.py @@ -96,7 +96,6 @@ def load_open_clip(clip_model, use_jit=True, device="cuda", clip_cache_path=None import open_clip # pylint: disable=import-outside-toplevel torch.backends.cuda.matmul.allow_tf32 = True - if checkpoint is None: pretrained = dict(open_clip.list_pretrained()) checkpoint = pretrained[clip_model] @@ -213,7 +212,10 @@ def load_clip_without_warmup(clip_model, use_jit, device, clip_cache_path, check @lru_cache(maxsize=None) -def load_clip(clip_model="ViT-B/32", use_jit=True, warmup_batch_size=1, clip_cache_path=None, device=None, checkpoint=None): +def load_clip( + clip_model="ViT-B/32", use_jit=True, warmup_batch_size=1, + clip_cache_path=None, device=None, checkpoint=None +): """Load clip then warmup""" if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" From a9181d9f1aa050ac9da04b58fa27fc88638169c3 Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Mon, 18 Sep 2023 13:23:33 +0200 Subject: [PATCH 3/8] linting --- clip_retrieval/clip_inference/mapper.py | 2 +- clip_retrieval/clip_inference/worker.py | 2 +- clip_retrieval/load_clip.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/clip_retrieval/clip_inference/mapper.py b/clip_retrieval/clip_inference/mapper.py index 7436207..21cf548 100644 --- a/clip_retrieval/clip_inference/mapper.py +++ b/clip_retrieval/clip_inference/mapper.py @@ -35,7 +35,7 @@ def __init__( self.use_mclip = use_mclip self.device = "cuda" if torch.cuda.is_available() else "cpu" model, _ = load_clip( - clip_model=clip_model, use_jit=use_jit, warmup_batch_size=warmup_batch_size, + clip_model=clip_model, use_jit=use_jit, warmup_batch_size=warmup_batch_size, clip_cache_path=clip_cache_path, checkpoint=checkpoint, ) self.model_img = model.encode_image diff --git a/clip_retrieval/clip_inference/worker.py b/clip_retrieval/clip_inference/worker.py index 53ba4e5..9302e1a 100644 --- a/clip_retrieval/clip_inference/worker.py +++ b/clip_retrieval/clip_inference/worker.py @@ -51,7 +51,7 @@ def worker( def reader_builder(sampler): _, preprocess = load_clip( - clip_model=clip_model, use_jit=use_jit, warmup_batch_size=batch_size, + clip_model=clip_model, use_jit=use_jit, warmup_batch_size=batch_size, clip_cache_path=clip_cache_path, checkpoint=clip_checkpoint, ) if input_format == "files": diff --git a/clip_retrieval/load_clip.py b/clip_retrieval/load_clip.py index ab5377f..a77658b 100644 --- a/clip_retrieval/load_clip.py +++ b/clip_retrieval/load_clip.py @@ -213,7 +213,7 @@ def load_clip_without_warmup(clip_model, use_jit, device, clip_cache_path, check @lru_cache(maxsize=None) def load_clip( - clip_model="ViT-B/32", use_jit=True, warmup_batch_size=1, + clip_model="ViT-B/32", use_jit=True, warmup_batch_size=1, clip_cache_path=None, device=None, checkpoint=None ): """Load clip then warmup""" From 2d96a4ec61c657534ad1b25ccda3879c73d70536 Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Mon, 18 Sep 2023 13:41:09 +0200 Subject: [PATCH 4/8] apply black for linting --- clip_retrieval/clip_inference/mapper.py | 7 +++-- clip_retrieval/clip_inference/worker.py | 7 +++-- clip_retrieval/load_clip.py | 42 +++++++++++++++++++------ 3 files changed, 42 insertions(+), 14 deletions(-) diff --git 
a/clip_retrieval/clip_inference/mapper.py b/clip_retrieval/clip_inference/mapper.py index 21cf548..a6cd1bf 100644 --- a/clip_retrieval/clip_inference/mapper.py +++ b/clip_retrieval/clip_inference/mapper.py @@ -35,8 +35,11 @@ def __init__( self.use_mclip = use_mclip self.device = "cuda" if torch.cuda.is_available() else "cpu" model, _ = load_clip( - clip_model=clip_model, use_jit=use_jit, warmup_batch_size=warmup_batch_size, - clip_cache_path=clip_cache_path, checkpoint=checkpoint, + clip_model=clip_model, + use_jit=use_jit, + warmup_batch_size=warmup_batch_size, + clip_cache_path=clip_cache_path, + checkpoint=checkpoint, ) self.model_img = model.encode_image self.model_txt = model.encode_text diff --git a/clip_retrieval/clip_inference/worker.py b/clip_retrieval/clip_inference/worker.py index 9302e1a..d264b3b 100644 --- a/clip_retrieval/clip_inference/worker.py +++ b/clip_retrieval/clip_inference/worker.py @@ -51,8 +51,11 @@ def worker( def reader_builder(sampler): _, preprocess = load_clip( - clip_model=clip_model, use_jit=use_jit, warmup_batch_size=batch_size, - clip_cache_path=clip_cache_path, checkpoint=clip_checkpoint, + clip_model=clip_model, + use_jit=use_jit, + warmup_batch_size=batch_size, + clip_cache_path=clip_cache_path, + checkpoint=clip_checkpoint, ) if input_format == "files": return FilesReader( diff --git a/clip_retrieval/load_clip.py b/clip_retrieval/load_clip.py index a77658b..422ac2b 100644 --- a/clip_retrieval/load_clip.py +++ b/clip_retrieval/load_clip.py @@ -50,7 +50,9 @@ def __init__(self, inner_model, device): if self.device.type == "cpu": self.dtype = torch.float32 else: - self.dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + self.dtype = ( + torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + ) def encode_image(self, image): if self.device.type == "cpu": @@ -90,7 +92,9 @@ def load_hf_clip(clip_model, device="cuda"): return model, lambda x: preprocess(x, return_tensors="pt").pixel_values -def load_open_clip(clip_model, use_jit=True, device="cuda", clip_cache_path=None, checkpoint=None): +def load_open_clip( + clip_model, use_jit=True, device="cuda", clip_cache_path=None, checkpoint=None +): """load open clip""" import open_clip # pylint: disable=import-outside-toplevel @@ -100,7 +104,11 @@ def load_open_clip(clip_model, use_jit=True, device="cuda", clip_cache_path=None pretrained = dict(open_clip.list_pretrained()) checkpoint = pretrained[clip_model] model, _, preprocess = open_clip.create_model_and_transforms( - clip_model, pretrained=checkpoint, device=device, jit=use_jit, cache_dir=clip_cache_path + clip_model, + pretrained=checkpoint, + device=device, + jit=use_jit, + cache_dir=clip_cache_path, ) model = OpenClipWrapper(inner_model=model, device=device) model.to(device=device) @@ -195,11 +203,15 @@ def get_tokenizer(clip_model): @lru_cache(maxsize=None) -def load_clip_without_warmup(clip_model, use_jit, device, clip_cache_path, checkpoint=None): +def load_clip_without_warmup( + clip_model, use_jit, device, clip_cache_path, checkpoint=None +): """Load clip""" if clip_model.startswith("open_clip:"): clip_model = clip_model[len("open_clip:") :] - model, preprocess = load_open_clip(clip_model, use_jit, device, clip_cache_path, checkpoint=checkpoint) + model, preprocess = load_open_clip( + clip_model, use_jit, device, clip_cache_path, checkpoint=checkpoint + ) elif clip_model.startswith("hf_clip:"): clip_model = clip_model[len("hf_clip:") :] model, preprocess = load_hf_clip(clip_model, device) @@ -207,19 +219,27 @@ def 
load_clip_without_warmup(clip_model, use_jit, device, clip_cache_path, check clip_model = clip_model[len("nm:") :] model, preprocess = load_deepsparse(clip_model) else: - model, preprocess = clip.load(clip_model, device=device, jit=use_jit, download_root=clip_cache_path) + model, preprocess = clip.load( + clip_model, device=device, jit=use_jit, download_root=clip_cache_path + ) return model, preprocess @lru_cache(maxsize=None) def load_clip( - clip_model="ViT-B/32", use_jit=True, warmup_batch_size=1, - clip_cache_path=None, device=None, checkpoint=None + clip_model="ViT-B/32", + use_jit=True, + warmup_batch_size=1, + clip_cache_path=None, + device=None, + checkpoint=None, ): """Load clip then warmup""" if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" - model, preprocess = load_clip_without_warmup(clip_model, use_jit, device, clip_cache_path, checkpoint=checkpoint) + model, preprocess = load_clip_without_warmup( + clip_model, use_jit, device, clip_cache_path, checkpoint=checkpoint + ) start = time.time() print(f"warming up with batch size {warmup_batch_size} on {device}", flush=True) @@ -232,7 +252,9 @@ def load_clip( def warmup(batch_size, device, preprocess, model): fake_img = Image.new("RGB", (224, 224), color="red") fake_text = ["fake"] * batch_size - image_tensor = torch.cat([torch.unsqueeze(preprocess(fake_img), 0)] * batch_size).to(device) + image_tensor = torch.cat( + [torch.unsqueeze(preprocess(fake_img), 0)] * batch_size + ).to(device) text_tokens = clip.tokenize(fake_text).to(device) for _ in range(2): with torch.no_grad(): From 6d255e1dedcef998ac395a7a90c3cb29361d1796 Mon Sep 17 00:00:00 2001 From: Romain Beaumont Date: Sun, 7 Jan 2024 00:16:58 +0100 Subject: [PATCH 5/8] fix lint --- clip_retrieval/load_clip.py | 28 +++++++--------------------- 1 file changed, 7 insertions(+), 21 deletions(-) diff --git a/clip_retrieval/load_clip.py b/clip_retrieval/load_clip.py index 422ac2b..0b7e869 100644 --- a/clip_retrieval/load_clip.py +++ b/clip_retrieval/load_clip.py @@ -50,9 +50,7 @@ def __init__(self, inner_model, device): if self.device.type == "cpu": self.dtype = torch.float32 else: - self.dtype = ( - torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 - ) + self.dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 def encode_image(self, image): if self.device.type == "cpu": @@ -92,9 +90,7 @@ def load_hf_clip(clip_model, device="cuda"): return model, lambda x: preprocess(x, return_tensors="pt").pixel_values -def load_open_clip( - clip_model, use_jit=True, device="cuda", clip_cache_path=None, checkpoint=None -): +def load_open_clip(clip_model, use_jit=True, device="cuda", clip_cache_path=None, checkpoint=None): """load open clip""" import open_clip # pylint: disable=import-outside-toplevel @@ -203,15 +199,11 @@ def get_tokenizer(clip_model): @lru_cache(maxsize=None) -def load_clip_without_warmup( - clip_model, use_jit, device, clip_cache_path, checkpoint=None -): +def load_clip_without_warmup(clip_model, use_jit, device, clip_cache_path, checkpoint=None): """Load clip""" if clip_model.startswith("open_clip:"): clip_model = clip_model[len("open_clip:") :] - model, preprocess = load_open_clip( - clip_model, use_jit, device, clip_cache_path, checkpoint=checkpoint - ) + model, preprocess = load_open_clip(clip_model, use_jit, device, clip_cache_path, checkpoint=checkpoint) elif clip_model.startswith("hf_clip:"): clip_model = clip_model[len("hf_clip:") :] model, preprocess = load_hf_clip(clip_model, device) @@ -219,9 +211,7 @@ 
def load_clip_without_warmup( clip_model = clip_model[len("nm:") :] model, preprocess = load_deepsparse(clip_model) else: - model, preprocess = clip.load( - clip_model, device=device, jit=use_jit, download_root=clip_cache_path - ) + model, preprocess = clip.load(clip_model, device=device, jit=use_jit, download_root=clip_cache_path) return model, preprocess @@ -237,9 +227,7 @@ def load_clip( """Load clip then warmup""" if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" - model, preprocess = load_clip_without_warmup( - clip_model, use_jit, device, clip_cache_path, checkpoint=checkpoint - ) + model, preprocess = load_clip_without_warmup(clip_model, use_jit, device, clip_cache_path, checkpoint=checkpoint) start = time.time() print(f"warming up with batch size {warmup_batch_size} on {device}", flush=True) @@ -252,9 +240,7 @@ def load_clip( def warmup(batch_size, device, preprocess, model): fake_img = Image.new("RGB", (224, 224), color="red") fake_text = ["fake"] * batch_size - image_tensor = torch.cat( - [torch.unsqueeze(preprocess(fake_img), 0)] * batch_size - ).to(device) + image_tensor = torch.cat([torch.unsqueeze(preprocess(fake_img), 0)] * batch_size).to(device) text_tokens = clip.tokenize(fake_text).to(device) for _ in range(2): with torch.no_grad(): From 2893ca2020a52371495e90f85d6a895fb44e30d6 Mon Sep 17 00:00:00 2001 From: Romain Beaumont Date: Sun, 7 Jan 2024 00:30:04 +0100 Subject: [PATCH 6/8] remove clip checkpoint arg, instead parse as / --- README.md | 4 ++-- clip_retrieval/clip_inference/main.py | 1 - clip_retrieval/clip_inference/mapper.py | 1 - clip_retrieval/clip_inference/worker.py | 3 --- clip_retrieval/load_clip.py | 14 ++++++++------ tests/test_clip_inference/test_mapper.py | 2 +- 6 files changed, 11 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 1204d60..72d7429 100644 --- a/README.md +++ b/README.md @@ -169,7 +169,7 @@ clip_inference turn a set of text+image into clip embeddings * **write_batch_size** Write batch size (default *10**6*) * **wds_image_key** Key to use for images in webdataset. (default *jpg*) * **wds_caption_key** Key to use for captions in webdataset. (default *txt*) -* **clip_model** CLIP model to load (default *ViT-B/32*). Specify it as `"open_clip:ViT-B-32-quickgelu"` to use the [open_clip](https://github.com/mlfoundations/open_clip) or `"hf_clip:patrickjohncyh/fashion-clip"` to use the [hugging face](https://huggingface.co/docs/transformers/model_doc/clip) clip model. +* **clip_model** CLIP model to load (default *ViT-B/32*). Specify it as `"open_clip:ViT-B-32/laion2b_s34b_b79k"` to use the [open_clip](https://github.com/mlfoundations/open_clip) or `"hf_clip:patrickjohncyh/fashion-clip"` to use the [hugging face](https://huggingface.co/docs/transformers/model_doc/clip) clip model. * **mclip_model** MCLIP model to load (default *sentence-transformers/clip-ViT-B-32-multilingual-v1*) * **use_mclip** If False it performs the inference using CLIP; MCLIP otherwise (default *False*) * **use_jit** uses jit for the clip model (default *True*) @@ -183,7 +183,7 @@ clip_inference turn a set of text+image into clip embeddings * **slurm_partition** (default *None*), the slurm partition to create a job in. * **slurm_jobs**, the number of jobs to create in slurm. (default *None*) * **slurm_job_comment**, the job comment to use. 
(default *None*) -* **slurm_nodelist**, a list of specific nodes to use .(default *None* +* **slurm_nodelist**, a list of specific nodes to use .(default *None*) * **slurm_exclude**, a list of nodes to exclude when creating jobs. (default *None*) * **slurm_job_timeout**, if not supplied it will default to 2 weeks. (default *None*) * **slurm_cache_path**, cache path to use for slurm-related tasks. (default *None*) diff --git a/clip_retrieval/clip_inference/main.py b/clip_retrieval/clip_inference/main.py index a8d0109..ee74377 100644 --- a/clip_retrieval/clip_inference/main.py +++ b/clip_retrieval/clip_inference/main.py @@ -80,7 +80,6 @@ def main( wds_image_key="jpg", wds_caption_key="txt", clip_model="ViT-B/32", - clip_checkpoint=None, mclip_model="sentence-transformers/clip-ViT-B-32-multilingual-v1", use_mclip=False, use_jit=False, diff --git a/clip_retrieval/clip_inference/mapper.py b/clip_retrieval/clip_inference/mapper.py index a6cd1bf..79fe717 100644 --- a/clip_retrieval/clip_inference/mapper.py +++ b/clip_retrieval/clip_inference/mapper.py @@ -39,7 +39,6 @@ def __init__( use_jit=use_jit, warmup_batch_size=warmup_batch_size, clip_cache_path=clip_cache_path, - checkpoint=checkpoint, ) self.model_img = model.encode_image self.model_txt = model.encode_text diff --git a/clip_retrieval/clip_inference/worker.py b/clip_retrieval/clip_inference/worker.py index d264b3b..f608b6b 100644 --- a/clip_retrieval/clip_inference/worker.py +++ b/clip_retrieval/clip_inference/worker.py @@ -34,7 +34,6 @@ def worker( wds_image_key="jpg", wds_caption_key="txt", clip_model="ViT-B/32", - clip_checkpoint=None, mclip_model="sentence-transformers/clip-ViT-B-32-multilingual-v1", use_mclip=False, use_jit=True, @@ -55,7 +54,6 @@ def reader_builder(sampler): use_jit=use_jit, warmup_batch_size=batch_size, clip_cache_path=clip_cache_path, - checkpoint=clip_checkpoint, ) if input_format == "files": return FilesReader( @@ -96,7 +94,6 @@ def mapper_builder(): mclip_model=mclip_model, clip_cache_path=clip_cache_path, warmup_batch_size=batch_size, - checkpoint=clip_checkpoint, ) def writer_builder(i): diff --git a/clip_retrieval/load_clip.py b/clip_retrieval/load_clip.py index 0b7e869..3208b92 100644 --- a/clip_retrieval/load_clip.py +++ b/clip_retrieval/load_clip.py @@ -90,13 +90,16 @@ def load_hf_clip(clip_model, device="cuda"): return model, lambda x: preprocess(x, return_tensors="pt").pixel_values -def load_open_clip(clip_model, use_jit=True, device="cuda", clip_cache_path=None, checkpoint=None): +def load_open_clip(clip_model, use_jit=True, device="cuda", clip_cache_path=None): """load open clip""" import open_clip # pylint: disable=import-outside-toplevel torch.backends.cuda.matmul.allow_tf32 = True - if checkpoint is None: + clip_model_parts = clip_model.split("/") + clip_model = clip_model_parts[0] + checkpoint = "/".join(clip_model_parts[1:]) + if checkpoint == "": pretrained = dict(open_clip.list_pretrained()) checkpoint = pretrained[clip_model] model, _, preprocess = open_clip.create_model_and_transforms( @@ -199,11 +202,11 @@ def get_tokenizer(clip_model): @lru_cache(maxsize=None) -def load_clip_without_warmup(clip_model, use_jit, device, clip_cache_path, checkpoint=None): +def load_clip_without_warmup(clip_model, use_jit, device, clip_cache_path): """Load clip""" if clip_model.startswith("open_clip:"): clip_model = clip_model[len("open_clip:") :] - model, preprocess = load_open_clip(clip_model, use_jit, device, clip_cache_path, checkpoint=checkpoint) + model, preprocess = load_open_clip(clip_model, use_jit, 
device, clip_cache_path) elif clip_model.startswith("hf_clip:"): clip_model = clip_model[len("hf_clip:") :] model, preprocess = load_hf_clip(clip_model, device) @@ -222,12 +225,11 @@ def load_clip( warmup_batch_size=1, clip_cache_path=None, device=None, - checkpoint=None, ): """Load clip then warmup""" if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" - model, preprocess = load_clip_without_warmup(clip_model, use_jit, device, clip_cache_path, checkpoint=checkpoint) + model, preprocess = load_clip_without_warmup(clip_model, use_jit, device, clip_cache_path) start = time.time() print(f"warming up with batch size {warmup_batch_size} on {device}", flush=True) diff --git a/tests/test_clip_inference/test_mapper.py b/tests/test_clip_inference/test_mapper.py index 3a0339a..9d48163 100644 --- a/tests/test_clip_inference/test_mapper.py +++ b/tests/test_clip_inference/test_mapper.py @@ -10,7 +10,7 @@ "model", [ "ViT-B/32", - "open_clip:ViT-B-32-quickgelu", + "open_clip:ViT-B-32/laion2b_s34b_b79k", "hf_clip:patrickjohncyh/fashion-clip", "nm:mgoin/CLIP-ViT-B-32-laion2b_s34b_b79k-ds", ], From 2c6f84322cf4cf9165ac793b44c681eaf178639a Mon Sep 17 00:00:00 2001 From: Romain Beaumont Date: Sun, 7 Jan 2024 00:33:26 +0100 Subject: [PATCH 7/8] remove additional load clip --- clip_retrieval/load_clip.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/clip_retrieval/load_clip.py b/clip_retrieval/load_clip.py index 3208b92..a60443e 100644 --- a/clip_retrieval/load_clip.py +++ b/clip_retrieval/load_clip.py @@ -79,17 +79,6 @@ def load_hf_clip(clip_model, device="cuda"): return model, lambda x: preprocess(x, return_tensors="pt").pixel_values -def load_hf_clip(clip_model, device="cuda"): - """load hf clip""" - from transformers import CLIPProcessor, CLIPModel # pylint: disable=import-outside-toplevel - - model = CLIPModel.from_pretrained(clip_model) - preprocess = CLIPProcessor.from_pretrained(clip_model).image_processor - model = HFClipWrapper(inner_model=model, device=device) - model.to(device=device) - return model, lambda x: preprocess(x, return_tensors="pt").pixel_values - - def load_open_clip(clip_model, use_jit=True, device="cuda", clip_cache_path=None): """load open clip""" From 229f9ce23d69e98d0000f25656278ce7244ee6f1 Mon Sep 17 00:00:00 2001 From: Romain Beaumont Date: Sun, 7 Jan 2024 00:34:25 +0100 Subject: [PATCH 8/8] missing checkpoint --- clip_retrieval/clip_inference/mapper.py | 1 - 1 file changed, 1 deletion(-) diff --git a/clip_retrieval/clip_inference/mapper.py b/clip_retrieval/clip_inference/mapper.py index 79fe717..c555916 100644 --- a/clip_retrieval/clip_inference/mapper.py +++ b/clip_retrieval/clip_inference/mapper.py @@ -27,7 +27,6 @@ def __init__( mclip_model, warmup_batch_size=1, clip_cache_path=None, - checkpoint=None, ): self.enable_image = enable_image self.enable_text = enable_text
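
A minimal usage sketch (illustrative only, not part of the patches): after patch 6/8, the pretrained weights of an open_clip model are selected by appending them to the architecture name with a "/" separator ("open_clip:<architecture>/<pretrained>") instead of a separate --clip_checkpoint flag; everything after the first "/" is forwarded to open_clip's `pretrained` argument, so a pretrained tag or (presumably) a local checkpoint path should both work. The snippet assumes the patched clip_retrieval package is installed and uses an example pretrained tag.

    import torch
    from PIL import Image

    from clip_retrieval.load_clip import load_clip

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # load_clip strips the "open_clip:" prefix, then splits on "/" to get the
    # architecture ("ViT-B-32") and the checkpoint ("laion2b_s34b_b79k")
    model, preprocess = load_clip(
        clip_model="open_clip:ViT-B-32/laion2b_s34b_b79k",
        use_jit=False,
        warmup_batch_size=1,
        device=device,
    )

    # encode one dummy image with the wrapped model
    image = preprocess(Image.new("RGB", (224, 224))).unsqueeze(0).to(device)
    with torch.no_grad():
        image_embedding = model.encode_image(image)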