Map location to cpu to save cuda memory.
zmgong committed Nov 19, 2024
1 parent eaec598 commit a1ba974
Showing 2 changed files with 2 additions and 7 deletions.
@@ -1,4 +1,4 @@
-batch_size: 490
+batch_size: 500
 epochs: 30
 labels_for_driven_positive_and_negative_pairs:
 wandb_project_name: BIOSCAN-CLIP-small_experiments
7 changes: 1 addition & 6 deletions bioscanclip/model/simple_clip.py
@@ -209,7 +209,7 @@ def load_clip_model(args, device=None):
         pre_trained_timm_vit = timm.create_model(image_model_name, pretrained=True)
         if hasattr(args.model_config.image, 'image_encoder_trained_with_simclr_style_ckpt_path'):
             image_encoder_trained_with_simclr_style_ckpt_path = args.model_config.image.image_encoder_trained_with_simclr_style_ckpt_path
-            checkpoint = torch.load(image_encoder_trained_with_simclr_style_ckpt_path)
+            checkpoint = torch.load(image_encoder_trained_with_simclr_style_ckpt_path, map_location='cpu')
             state_dict = checkpoint['state_dict']
             state_dict = remove_module_from_state_dict(state_dict)
             pre_trained_timm_vit.load_state_dict(state_dict)
@@ -219,11 +219,6 @@ def load_clip_model(args, device=None):
             del checkpoint
             torch.cuda.empty_cache()
             print("Loaded image encoder from %s" % image_encoder_trained_with_simclr_style_ckpt_path)
-        # Check the memory usage of the image encoder
-
-        print(torch.cuda.mem_get_info())
-        exit()
-        # pre_trained_timm_vit = timm.create_model('vit_base_patch16_224', pretrained=True)
         if disable_lora:
             image_encoder = LoRA_ViT_timm(vit_model=pre_trained_timm_vit, r=4,
                                           num_classes=args.model_config.output_dim, lora_layer=[])
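Why this helps: by default, torch.load restores each tensor onto the device it was saved from, so a checkpoint written on a GPU is deserialized straight back into CUDA memory, alongside the model that is about to receive the weights. Passing map_location='cpu' keeps the deserialized tensors in host RAM; load_state_dict then copies them into the model, and the temporary copy can be freed. A minimal sketch of the pattern follows; the checkpoint path and the plain timm ViT are illustrative rather than the repo's config-driven setup, and the repo additionally strips a DataParallel 'module.' prefix via remove_module_from_state_dict, which is omitted here:

    import timm
    import torch

    ckpt_path = "simclr_vit_checkpoint.pth"  # hypothetical path, not from the repo

    # Build the backbone on CPU; its weights are about to be overwritten anyway.
    model = timm.create_model("vit_base_patch16_224", pretrained=False)

    # map_location='cpu' deserializes the saved tensors into host RAM instead of
    # the CUDA device they were saved from, so no GPU memory is allocated here.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])  # assumes a 'state_dict' key, as in the repo

    # Drop the temporary CPU copy, then move the model to the GPU once.
    del checkpoint
    if torch.cuda.is_available():
        model = model.cuda()
        torch.cuda.empty_cache()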
