From 3ad01485399982faa440ff3ea3909ce098f602e9 Mon Sep 17 00:00:00 2001 From: zmgong Date: Wed, 24 Jul 2024 15:17:13 -0700 Subject: [PATCH] Using the ckpt that trained with contrastive learning on the INSECT dataset if available. --- .../supervised_fine_tune_bioscan_clip_model_on_insect.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/scripts/supervised_fine_tune_bioscan_clip_model_on_insect.py b/scripts/supervised_fine_tune_bioscan_clip_model_on_insect.py index 82b2a01..1a7d1b7 100644 --- a/scripts/supervised_fine_tune_bioscan_clip_model_on_insect.py +++ b/scripts/supervised_fine_tune_bioscan_clip_model_on_insect.py @@ -56,9 +56,11 @@ def main(args: DictConfig) -> None: unique_species_for_seen = get_unique_species_for_seen(insect_trainval_dataloader) print("Load model...") - original_model = load_clip_model(args) - checkpoint = torch.load(args.model_config.ckpt_path, map_location='cuda:0') + if hasattr(args.model_config, 'ckpt_trained_with_insect_image_dna_text_path'): + checkpoint = torch.load(args.model_config.ckpt_trained_with_insect_image_dna_text_path, map_location='cuda:0') + else: + checkpoint = torch.load(args.model_config.ckpt_path, map_location='cuda:0') original_model.load_state_dict(checkpoint) original_model = original_model.to(device)