Update LM generalist scripts #822

Draft · wants to merge 19 commits into base: dev
Changes from 15 commits
6 changes: 6 additions & 0 deletions .gitignore
@@ -174,9 +174,15 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Additional stuff to avoid tracking.
# Torch-em stuff
checkpoints/
logs/

# "gpu_jobs" folder where slurm batch submission scripts are saved
gpu_jobs/
sam_embeddings/
results/
iterative_prompting_results/
*.sh
*.png
6 changes: 0 additions & 6 deletions finetuning/.gitignore

This file was deleted.

239 changes: 182 additions & 57 deletions finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py

Large diffs are not rendered by default.

@@ -3,17 +3,14 @@

import torch

from torch_em.model import UNETR
from torch_em.loss import DiceBasedDistanceLoss

import micro_sam.training as sam_training
from micro_sam.util import export_custom_sam_model

from obtain_lm_datasets import get_generalist_lm_loaders


def finetune_lm_generalist(args):
"""Code for finetuning SAM on multiple light microscopy datasets"""
"""Code for finetuning SAM on multiple Light Microscopy datasets."""
    # override this (below) if you have a more complex setup and need to specify the exact GPU
device = "cuda" if torch.cuda.is_available() else "cpu"

@@ -22,92 +19,53 @@ def finetune_lm_generalist(args):
checkpoint_path = None # override this to start training from a custom checkpoint
patch_shape = (512, 512) # the patch shape for training
n_objects_per_batch = args.n_objects # this is the number of objects per batch that will be sampled (default: 25)
freeze_parts = None # override this to freeze one or more of these backbones

# get the trainable segment anything model
model = sam_training.get_trainable_sam_model(
model_type=model_type,
device=device,
checkpoint_path=checkpoint_path,
freeze=freeze_parts
)
model.to(device)

# let's get the UNETR model for automatic instance segmentation pipeline
unetr = UNETR(
backbone="sam",
encoder=model.sam.image_encoder,
out_channels=3,
use_sam_stats=True,
final_activation="Sigmoid",
use_skip_connection=False,
resize_input=True,
use_conv_transpose=True
)
unetr.to(device)

# let's get the parameters for SAM and the decoder from UNETR
joint_model_params = [params for params in model.parameters()] # sam parameters
for name, params in unetr.named_parameters(): # unetr's decoder parameters
if not name.startswith("encoder"):
joint_model_params.append(params)
checkpoint_name = f"{args.model_type}/lm_generalist_sam"

# all the stuff we need for training
optimizer = torch.optim.Adam(joint_model_params, lr=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10, verbose=True)
train_loader, val_loader = get_generalist_lm_loaders(input_path=args.input_path, patch_shape=patch_shape)
scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 5, "verbose": True}

# this class creates all the training data for a batch (inputs, prompts and labels)
convert_inputs = sam_training.ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025)

checkpoint_name = f"{args.model_type}/lm_generalist_sam"

# the trainer which performs the joint training and validation (implemented using "torch_em")
trainer = sam_training.JointSamTrainer(
# Run training.
sam_training.train_sam(
name=checkpoint_name,
save_root=args.save_root,
model_type=model_type,
train_loader=train_loader,
val_loader=val_loader,
model=model,
optimizer=optimizer,
device=device,
lr_scheduler=scheduler,
logger=sam_training.JointSamLogger,
log_image_interval=100,
mixed_precision=True,
convert_inputs=convert_inputs,
early_stopping=None, # NOTE: Avoid early stopping for training the generalist model.
n_objects_per_batch=n_objects_per_batch,
n_sub_iteration=8,
compile_model=False,
mask_prob=0.5, # (optional) overwrite to provide the probability of using mask inputs while training
unetr=unetr,
instance_loss=DiceBasedDistanceLoss(mask_distances_in_bg=True),
instance_metric=DiceBasedDistanceLoss(mask_distances_in_bg=True)
checkpoint_path=checkpoint_path,
with_segmentation_decoder=True,
device=device,
lr=1e-5,
n_iterations=args.iterations,
save_root=args.save_root,
scheduler_kwargs=scheduler_kwargs,
verify_n_labels_in_loader=None, # NOTE: Verifies all labels in the loader(s).
box_distortion_factor=0.05,
load_weights=(not args.from_scratch),
)
trainer.fit(args.iterations, save_every_kth_epoch=args.save_every_kth_epoch)

if args.export_path is not None:
checkpoint_path = os.path.join(
"" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt"
)
export_custom_sam_model(
checkpoint_path=checkpoint_path,
model_type=model_type,
save_path=args.export_path,
checkpoint_path=checkpoint_path, model_type=model_type, save_path=args.export_path
)


def main():
parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LM datasets.")
parser.add_argument(
"--input_path", "-i", default="/scratch/projects/nim00007/sam/data/",
"--input_path", "-i", default="/mnt/vast-nhr/projects/cidas/cca/experiments/micro_sam/data",
help="The filepath to all the respective LM datasets. If the data does not exist yet it will be downloaded"
)
parser.add_argument(
"--model_type", "-m", default="vit_b",
help="The model type to use for fine-tuning. Either vit_t, vit_b, vit_l or vit_h."
help="The model type to use for fine-tuning. Either 'vit_t', 'vit_b', 'vit_l' or 'vit_h'."
)
parser.add_argument(
"--save_root", "-s",
"--save_root", "-s", default=None,
help="Where to save the checkpoint and logs. By default they will be saved where this script is run from."
)
parser.add_argument(
@@ -119,11 +77,10 @@ def main():
help="Where to export the finetuned model to. The exported model can be used in the annotation tools."
)
parser.add_argument(
"--save_every_kth_epoch", type=int, default=None,
help="To save every kth epoch while fine-tuning. Expects an integer value."
"--n_objects", type=int, default=25, help="The number of instances (objects) per batch used for finetuning."
)
parser.add_argument(
"--n_objects", type=int, default=25, help="The number of instances (objects) per batch used for finetuning."
"--from_scratch", action="store_true", help="Whether to train Segment Anything model from scratch."
)
args = parser.parse_args()
finetune_lm_generalist(args)
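The refactor above replaces the hand-wired `JointSamTrainer` setup (trainable SAM model, UNETR decoder, optimizer, scheduler and trainer) with a single higher-level `sam_training.train_sam` call. A minimal sketch of the new entry point, wrapped in a hypothetical helper; the loaders are assumed to come from `get_generalist_lm_loaders` as in the script above, and the iteration count is illustrative:

import torch

import micro_sam.training as sam_training


def run_lm_generalist_training(train_loader, val_loader, save_root=None):
    # Sketch of the refactored training call; the arguments mirror the diff above.
    sam_training.train_sam(
        name="vit_b/lm_generalist_sam",
        model_type="vit_b",
        train_loader=train_loader,
        val_loader=val_loader,
        n_objects_per_batch=25,
        with_segmentation_decoder=True,  # joint training with the instance segmentation decoder
        device="cuda" if torch.cuda.is_available() else "cpu",
        lr=1e-5,
        n_iterations=100000,  # hypothetical iteration count
        save_root=save_root,
        scheduler_kwargs={"mode": "min", "factor": 0.9, "patience": 5, "verbose": True},
        verify_n_labels_in_loader=None,  # None verifies all labels in the loader(s)
        box_distortion_factor=0.05,
        load_weights=True,  # set to False to train from scratch
    )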
64 changes: 35 additions & 29 deletions finetuning/run_all_finetuning.py
@@ -5,51 +5,46 @@


N_OBJECTS = {
"vit_t": 50,
# "vit_t": 50,
"vit_b": 40,
"vit_l": 30,
"vit_h": 25
# "vit_h": 25,
}


def write_batch_script(out_path, _name, env_name, model_type, save_root):
def write_batch_script(out_path, _name, env_name, model_type, save_root, dry):
"Writing scripts with different micro-sam finetunings."
batch_script = f"""#!/bin/bash
#SBATCH -t 14-00:00:00
#SBATCH --mem 64G
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH -p grete:shared
#SBATCH -G A100:1
#SBATCH -A nim00007
#SBATCH -p grete-h100:shared
#SBATCH -G H100:1
#SBATCH -A gzz0001
#SBATCH -c 16
#SBATCH --mem 64G
#SBATCH --qos=14d
#SBATCH --constraint=80gb
#SBATCH --job-name={os.path.split(_name)[-1]}

source activate {env_name} \n"""
source ~/.bashrc
micromamba activate {env_name} \n"""

# python script
# The python script
python_script = f"python {_name}.py "
python_script += f"-s {save_root} " # The save root folder
python_script += f"-m {model_type} " # The name of the model configuration
python_script += f"--n_objects {N_OBJECTS[model_type[:5]]} " # The choice of the number of objects

# save root folder
python_script += f"-s {save_root} "

# name of the model configuration
python_script += f"-m {model_type} "

# choice of the number of objects
python_script += f"--n_objects {N_OBJECTS[model_type[:5]]} "

# let's add the python script to the bash script
# Add the python script to the bash script
batch_script += python_script

_op = out_path[:-3] + f"_{os.path.split(_name)[-1]}.sh"
with open(_op, "w") as f:
f.write(batch_script)

cmd = ["sbatch", _op]
subprocess.run(cmd)
if not dry:
subprocess.run(cmd)


def get_batch_script_names(tmp_folder):
@@ -101,26 +96,37 @@ def submit_slurm(args):
write_batch_script(
out_path=get_batch_script_names(tmp_folder),
_name=script_name,
env_name="mobilesam" if model_type == "vit_t" else "sam",
env_name="mobilesam" if model_type == "vit_t" else "super",
model_type=model_type,
save_root=args.save_root
save_root=args.save_root,
dry=args.dry,
)


def main(args):
try:
tmp_dir = "./gpu_jobs"
if os.path.exists(tmp_dir):
shutil.rmtree("./gpu_jobs")
except FileNotFoundError:
pass

submit_slurm(args)


if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("-e", "--experiment_name", type=str, default=None)
parser.add_argument("-s", "--save_root", type=str, default="/scratch/usr/nimanwai/micro-sam/")
parser.add_argument("-m", "--model_type", type=str, default=None)
parser.add_argument(
"-e", "--experiment_name", type=str, default=None, help="The choice of experiment name.",
)
parser.add_argument(
"-s", "--save_root", type=str, default="/mnt/vast-nhr/projects/cidas/cca/experiments/micro_sam",
help="The path where to store the model checkpoints and logs.",
)
parser.add_argument(
"-m", "--model_type", type=str, default=None, help="The choice of model type for Segment Anything model."
)
parser.add_argument(
"--dry", action="store_true", help="Whether to submit the scripts to slurm or only store the scripts."
)
args = parser.parse_args()

main(args)
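The new `dry` flag makes it possible to generate and inspect the batch scripts without submitting them to slurm. A small usage sketch (hypothetical script stem and save root), assuming `write_batch_script` and `get_batch_script_names` are importable from `run_all_finetuning`:

import os

from run_all_finetuning import get_batch_script_names, write_batch_script

tmp_folder = "./gpu_jobs"
os.makedirs(tmp_folder, exist_ok=True)  # assumes get_batch_script_names does not create it

# Write the batch script for a single vit_b run, but skip the sbatch submission.
write_batch_script(
    out_path=get_batch_script_names(tmp_folder),
    _name="finetune_lm_generalist",  # hypothetical script stem, resolved to <stem>.py
    env_name="super",
    model_type="vit_b",
    save_root="/tmp/micro_sam",  # hypothetical save root
    dry=True,
)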
9 changes: 7 additions & 2 deletions micro_sam/training/training.py
@@ -188,6 +188,8 @@ def train_sam(
peft_kwargs: Optional[Dict] = None,
ignore_warnings: bool = True,
verify_n_labels_in_loader: Optional[int] = 50,
box_distortion_factor: Optional[float] = 0.025,
load_weights: bool = True,
**model_kwargs,
) -> None:
"""Run training for a SAM model.
@@ -228,8 +230,10 @@
peft_kwargs: Keyword arguments for the PEFT wrapper class.
verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders.
By default, 50 batches of labels are verified from the dataloaders.
model_kwargs: Additional keyword arguments for the `util.get_sam_model`.
ignore_warnings: Whether to ignore raised warnings.
box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks.
load_weights: Whether to initialize the model with pretrained parameter weights.
model_kwargs: Additional keyword arguments for the `util.get_sam_model`.
"""
with _filter_warnings(ignore_warnings):

@@ -247,11 +251,12 @@
checkpoint_path=checkpoint_path,
return_state=True,
peft_kwargs=peft_kwargs,
load_weights=load_weights,
**model_kwargs
)

# This class creates all the training data for a batch (inputs, prompts and labels).
convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025)
convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=box_distortion_factor)

# Create the UNETR decoder (if train with it) and the optimizer.
if with_segmentation_decoder:
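The new `box_distortion_factor` is forwarded to `ConvertToSamInputs`, which derives box prompts from the ground-truth masks. Conceptually, the factor controls how strongly those boxes are jittered relative to their size; the following standalone sketch illustrates the idea (it is not micro_sam's exact implementation):

import numpy as np


def distort_box(box, distortion_factor=0.05, rng=None):
    # Jitter a (min_y, min_x, max_y, max_x) box by random offsets that are
    # proportional to its height and width, as controlled by distortion_factor.
    rng = np.random.default_rng() if rng is None else rng
    min_y, min_x, max_y, max_x = box
    height, width = max_y - min_y, max_x - min_x
    dy = rng.uniform(-distortion_factor, distortion_factor, size=2) * height
    dx = rng.uniform(-distortion_factor, distortion_factor, size=2) * width
    return (min_y + dy[0], min_x + dx[0], max_y + dy[1], max_x + dx[1])


print(distort_box((10, 20, 110, 220), distortion_factor=0.05))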
32 changes: 18 additions & 14 deletions micro_sam/training/util.py
@@ -18,6 +18,7 @@

from torch_em.transform.label import PerObjectDistanceTransform
from torch_em.transform.raw import normalize_percentile, normalize
from torch_em.data.datasets.light_microscopy.neurips_cell_seg import to_rgb


def identity(x):
@@ -46,21 +47,22 @@ def get_trainable_sam_model(
return_state: bool = False,
peft_kwargs: Optional[Dict] = None,
flexible_load_checkpoint: bool = False,
load_weights: bool = True,
**model_kwargs
) -> TrainableSAM:
"""Get the trainable sam model.

Args:
model_type: The segment anything model that should be finetuned.
The weights of this model will be used for initialization, unless a
custom weight file is passed via `checkpoint_path`.
model_type: The segment anything model that should be finetuned. The weights of this model
will be used for initialization, unless a custom weight file is passed via `checkpoint_path`.
device: The device to use for training.
checkpoint_path: Path to a custom checkpoint from which to load the model weights.
freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder
By default nothing is frozen and the full model is updated.
freeze: Specify parts of the model that should be frozen, namely: `image_encoder`, `prompt_encoder` and
`mask_decoder`. By default nothing is frozen and the full model is updated.
return_state: Whether to return the full checkpoint state.
peft_kwargs: Keyword arguments for the PEFT wrapper class.
flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
load_weights: Whether to initialize the model with pretrained parameter weights.
model_kwargs: Additional keyword arguments for the `util.get_sam_model`.

Returns:
Expand All @@ -75,6 +77,7 @@ def get_trainable_sam_model(
return_sam=True,
return_state=True,
flexible_load_checkpoint=flexible_load_checkpoint,
load_weights=load_weights,
**model_kwargs
)
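For reference, a short sketch of the updated `get_trainable_sam_model` signature in use, e.g. to freeze the image encoder or to skip loading the pretrained weights; this assumes `freeze` accepts a list of the part names given in the docstring:

import micro_sam.training as sam_training

# Finetune only the prompt encoder and mask decoder by freezing the image encoder.
model = sam_training.get_trainable_sam_model(
    model_type="vit_b", device="cpu", freeze=["image_encoder"],
)

# Start from a random initialization (new `load_weights` argument) for from-scratch training.
scratch_model = sam_training.get_trainable_sam_model(
    model_type="vit_b", device="cpu", load_weights=False,
)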

@@ -258,19 +261,20 @@ def __init__(self, desired_shape, do_rescaling=False, padding="constant"):
self.do_rescaling = do_rescaling

def __call__(self, raw):
        raw = to_rgb(raw)  # Ensure all images have 3 channels: replicate a single channel to three.

if self.do_rescaling:
raw = normalize_percentile(raw, axis=(1, 2))
raw = np.mean(raw, axis=0)
            # NOTE: The normalization below is done for TissueNet, to work with the valid channels (i.e. the first and second channels).
raw = normalize_percentile(raw, axis=(0, 1))
raw = normalize(raw)
raw = raw * 255

tmp_ddim = (self.desired_shape[0] - raw.shape[0], self.desired_shape[1] - raw.shape[1])
ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2)
raw = np.pad(
raw,
pad_width=((ceil(ddim[0]), floor(ddim[0])), (ceil(ddim[1]), floor(ddim[1]))),
mode=self.padding
)
# Pad the inputs to the desired shape.
tmp_ddim = [desired - curr for desired, curr in zip(self.desired_shape, raw.shape)]
ddim = [(per_dim / 2) for per_dim in tmp_ddim]
pad_width = [(ceil(d), floor(d)) for d in ddim]
raw = np.pad(raw, pad_width=pad_width, mode=self.padding)

assert raw.shape == self.desired_shape
return raw

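The rewritten padding block generalizes the previous hard-coded 2D logic to any number of dimensions. A self-contained sketch of the same center-padding scheme:

from math import ceil, floor

import numpy as np


def pad_to_shape(raw, desired_shape, mode="constant"):
    # Center-pad `raw` so that raw.shape matches desired_shape.
    # Assumes raw is not larger than desired_shape in any dimension.
    ddim = [(desired - curr) / 2 for desired, curr in zip(desired_shape, raw.shape)]
    pad_width = [(ceil(d), floor(d)) for d in ddim]
    return np.pad(raw, pad_width=pad_width, mode=mode)


image = np.random.rand(487, 503)
padded = pad_to_shape(image, (512, 512))
assert padded.shape == (512, 512)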