diff --git a/.gitignore b/.gitignore index 734e5772..cbb13274 100644 --- a/.gitignore +++ b/.gitignore @@ -174,9 +174,15 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ +# Additional stuff to avoid tracking. # Torch-em stuff checkpoints/ logs/ # "gpu_jobs" folder where slurm batch submission scripts are saved gpu_jobs/ +sam_embeddings/ +results/ +iterative_prompting_results/ +*.sh +*.png diff --git a/finetuning/.gitignore b/finetuning/.gitignore deleted file mode 100644 index b078b91d..00000000 --- a/finetuning/.gitignore +++ /dev/null @@ -1,6 +0,0 @@ -checkpoints/ -logs/ -sam_embeddings/ -results/ -iterative_prompting_results/ -*.sh diff --git a/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py b/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py index 8ac629b4..f20eefde 100644 --- a/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py +++ b/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py @@ -1,124 +1,250 @@ import os + import numpy as np +from sklearn.model_selection import train_test_split import torch + import torch_em -import torch_em.data.datasets as datasets +from torch_em.data import datasets +from torch_em.transform.raw import normalize from torch_em.data import MinInstanceSampler, ConcatDataset from torch_em.transform.label import PerObjectDistanceTransform -from torch_em.transform.raw import normalize_percentile, normalize +from torch_em.data.datasets.light_microscopy.neurips_cell_seg import to_rgb -from micro_sam.training import identity from micro_sam.training.util import ResizeRawTrafo, ResizeLabelTrafo -def neurips_raw_trafo(raw): - raw = datasets.neurips_cell_seg.to_rgb(raw) # ensures 3 channels for the neurips data - raw = normalize_percentile(raw) - raw = np.mean(raw, axis=0) - raw = normalize(raw) - raw = raw * 255 +def _to_8bit(raw): + "Ensures three channels for inputs and rescale them to 8 bit." + if raw.ndim == 3 and raw.shape[0] == 1: # If the inputs have 1 channel, we triplicate it. + raw = np.concatenate([raw] * 3, axis=0) + + raw = to_rgb(raw) # Ensure all images are in 3-channels: triplicate one channel to three channels. + raw = normalize(raw) * 255 return raw -def to_8bit(raw): - raw = normalize(raw) - raw = raw * 255 - return raw +def _identity(x): + "Ensures three channels for inputs and avoids rescaling inputs." + x = to_rgb(x) + return x + + +def _cellpose_raw_trafo(x): + """Transforms input images to desired format. + + NOTE: The input channel logic is arranged a bit strangely in `cyto` dataset. + This function takes care of it here. + """ + r, g, b = x + + assert g.max() != 0 + if r.max() == 0: + # The image is 1 channel and exists in green channel only. + assert b.max() == 0 + x = np.concatenate([g[None]] * 3, axis=0) + + elif r.max() != 0 and g.max() != 0: + # The image is 2 channels and we sort the channels such that - 0: cell, 1: nucleus + x = np.stack([g, r, np.zeros_like(b)], axis=0) + + x = to_rgb(x) # Ensures three channels for inputs and avoids rescaling inputs. + + return x def get_concat_lm_datasets(input_path, patch_shape, split_choice): assert split_choice in ["train", "val"] label_dtype = torch.float32 - sampler = MinInstanceSampler() + sampler = MinInstanceSampler(min_size=10) - def get_label_transform(min_size=0): + def _get_label_transform(): label_transform = PerObjectDistanceTransform( - distances=True, boundary_distances=True, directed_distances=False, - foreground=True, instances=True, min_size=min_size + distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True, ) return label_transform - def get_ctc_datasets( - input_path, patch_shape, sampler, raw_transform, label_transform, - ignore_datasets=["Fluo-N2DH-GOWT1", "Fluo-N2DL-HeLa"] - ): + def get_livecell_datasets(): + "Datasets for cell segmentation in phase contrast microscopy images." + all_livecell_datasets = [ + datasets.get_livecell_dataset( + path=os.path.join(input_path, "livecell"), split=split_choice, patch_shape=patch_shape, + sampler=sampler, label_dtype=label_dtype, raw_transform=_identity, download=True, cell_types=[ctype], + label_transform=_get_label_transform(), n_samples=200 if split_choice == "train" else 50, + ) for ctype in datasets.livecell.CELL_TYPES + ] + return all_livecell_datasets + + def get_embedseg_datasets(): + "Datasets for cell and nuclei segmentation in fluorescence microscopy images." + names = [ + "Mouse-Organoid-Cells-CBG", + "Mouse-Skull-Nuclei-CBG", + "Platynereis-ISH-Nuclei-CBG", + "Platynereis-Nuclei-CBG", + ] + all_embedseg_datasets = [ + datasets.get_embedseg_dataset( + path=os.path.join(input_path, "embedseg"), name=name, patch_shape=(1, *patch_shape), + download=True, n_samples=500 if split_choice == "train" else 100, raw_transform=_to_8bit, + label_dtype=label_dtype, label_transform=_get_label_transform(), ndim=2, + sampler=MinInstanceSampler(min_num_instances=3, min_size=10), + ) for name in names + ] + return all_embedseg_datasets + + def get_yeaz_dataset(): + "Datasets for yeast segmentation in phase contrast and brightfield microscopy images." + names = ["bf", "phc"] + all_yeaz_datasets = [ + datasets.get_yeaz_dataset( + path=os.path.join(input_path, "yeaz"), patch_shape=patch_shape, raw_transform=_to_8bit, + ndim=2, download=True, split=split_choice, choice=name, label_transform=_get_label_transform(), + sampler=sampler, label_dtype=label_dtype, + ) for name in names + ] + return all_yeaz_datasets + + def get_cvz_dataset(stain_choice): + "Datasets for cell and nuclei segmentation in fluorescence microscopy images." + # NOTE: We create random splits for this dataset for training the generalist. + raw_paths, label_paths = datasets.cvz_fluo.get_cvz_fluo_paths( + path=os.path.join(input_path, "cvz"), stain_choice=stain_choice, + ) + train_raw_paths, test_raw_paths, train_label_paths, test_label_paths = train_test_split( + raw_paths, label_paths, test_size=0.2, random_state=42, + ) + ds = torch_em.default_segmentation_dataset( + raw_paths=train_raw_paths if split_choice == "train" else test_raw_paths, raw_key=None, + label_paths=train_label_paths if split_choice == "train" else test_label_paths, + label_key=None, is_seg_dataset=False, patch_shape=patch_shape, sampler=sampler, + raw_transform=_identity, label_transform=_get_label_transform(), label_dtype=label_dtype, + ) + return [ds] + + def get_ctc_datasets(): + "Datasets for cell segmentation in different modalities." all_ctc_datasets = [] for dataset_name in datasets.ctc.CTC_CHECKSUMS["train"].keys(): - if dataset_name in ignore_datasets: + if dataset_name in ["Fluo-N2DH-GOWT1", "Fluo-N2DL-HeLa"]: continue all_ctc_datasets.append( datasets.get_ctc_segmentation_dataset( path=os.path.join(input_path, "ctc"), dataset_name=dataset_name, patch_shape=(1, *patch_shape), - sampler=sampler, raw_transform=raw_transform, label_transform=label_transform, download=True + sampler=sampler, raw_transform=_to_8bit, label_transform=_get_label_transform(), + download=True, label_dtype=label_dtype, ) ) return all_ctc_datasets _datasets = [ + # cell segmentation in tissue microscopy images. datasets.get_tissuenet_dataset( path=os.path.join(input_path, "tissuenet"), split=split_choice, download=True, patch_shape=patch_shape, - raw_channel="rgb", label_channel="cell", sampler=sampler, label_dtype=label_dtype, - raw_transform=ResizeRawTrafo(patch_shape, do_rescaling=True), - label_transform=ResizeLabelTrafo(patch_shape, min_size=0), - n_samples=1000 if split_choice == "train" else 100 - ), - datasets.get_livecell_dataset( - path=os.path.join(input_path, "livecell"), split=split_choice, patch_shape=patch_shape, - download=True, label_transform=get_label_transform(), sampler=sampler, - label_dtype=label_dtype, raw_transform=identity + raw_channel="rgb", label_channel="cell", label_transform=ResizeLabelTrafo(patch_shape), + # NOTE: Below is done for TissueNet: to work with the valid channels (i.e. the first and second channels). + raw_transform=ResizeRawTrafo((3, *patch_shape), do_rescaling=True, valid_channels=(0, 1)), + n_samples=500 if split_choice == "train" else 100, sampler=sampler, label_dtype=label_dtype, ), + # bacteria segmentation in label-free microscopy images. datasets.get_deepbacs_dataset( path=os.path.join(input_path, "deepbacs"), split=split_choice, patch_shape=patch_shape, - raw_transform=to_8bit, label_transform=get_label_transform(), label_dtype=label_dtype, - download=True, sampler=MinInstanceSampler(min_num_instances=4) + raw_transform=_to_8bit, label_transform=_get_label_transform(), label_dtype=label_dtype, + download=True, sampler=MinInstanceSampler(min_num_instances=4, min_size=10) + ), + # cell segmentation in confocal microscopy images. + datasets.get_plantseg_dataset( + path=os.path.join(input_path, "plantseg"), name="root", n_samples=500 if split_choice == "train" else 100, + patch_shape=(1, *patch_shape), download=True, ndim=2, raw_transform=ResizeRawTrafo((3, *patch_shape)), + sampler=MinInstanceSampler(min_num_instances=4, min_size=10), split=split_choice, label_dtype=label_dtype, + label_transform=ResizeLabelTrafo(patch_shape), ), + # cell segmentation in multi-modal microscopy images. datasets.get_neurips_cellseg_supervised_dataset( - root=os.path.join(input_path, "neurips-cell-seg"), split=split_choice, - patch_shape=patch_shape, raw_transform=neurips_raw_trafo, label_transform=get_label_transform(), - label_dtype=label_dtype, sampler=MinInstanceSampler(min_num_instances=3) + root=os.path.join(input_path, "neurips_cellseg"), split=split_choice, label_dtype=label_dtype, + patch_shape=patch_shape, raw_transform=_to_8bit, label_transform=_get_label_transform(), + sampler=MinInstanceSampler(min_num_instances=3, min_size=10), download=True, ), + # nuclei segmentation in fluorescence microscopy images. datasets.get_dsb_dataset( path=os.path.join(input_path, "dsb"), split=split_choice if split_choice == "train" else "test", - patch_shape=patch_shape, label_transform=get_label_transform(), sampler=sampler, - label_dtype=label_dtype, download=True, raw_transform=identity + patch_shape=patch_shape, label_transform=_get_label_transform(), sampler=sampler, + label_dtype=label_dtype, download=True, raw_transform=_identity, ), - datasets.get_plantseg_dataset( - path=os.path.join(input_path, "plantseg"), name="root", sampler=MinInstanceSampler(min_num_instances=10), - patch_shape=(1, *patch_shape), download=True, split=split_choice, ndim=2, label_dtype=label_dtype, - raw_transform=ResizeRawTrafo(patch_shape, do_rescaling=False), - label_transform=ResizeLabelTrafo(patch_shape, min_size=0), - n_samples=1000 if split_choice == "train" else 100 + # nuclei segmentation in fluorescence microscopy images. + datasets.get_dynamicnuclearnet_dataset( + path=os.path.join(input_path, "dynamicnuclearnet"), patch_shape=patch_shape, download=True, sampler=sampler, + split=split_choice, n_samples=500 if split_choice == "train" else 100, label_dtype=label_dtype, + raw_transform=_to_8bit, label_transform=_get_label_transform(), + ), + # cell segmentation in multiple microscopy imaging modalities. + datasets.get_cellpose_dataset( + path=os.path.join(input_path, "cellpose"), patch_shape=patch_shape, choice="cyto", sampler=sampler, + download=True, split=split_choice if split_choice == "train" else "test", label_dtype=label_dtype, + label_transform=_get_label_transform(), raw_transform=_cellpose_raw_trafo, + ), + # bacteria segmentation in phase contrast and fluorescence microscopy images. + # worm segmentation in brightfield microscopy images. + datasets.get_omnipose_dataset( + path=os.path.join(input_path, "omnipose"), patch_shape=patch_shape, download=True, + split=split_choice if split_choice == "train" else "test", sampler=sampler, + label_dtype=label_dtype, raw_transform=_to_8bit, label_transform=_get_label_transform(), + ), + # organoid segmentation in brightfield microscopy images. + datasets.get_orgasegment_dataset( + path=os.path.join(input_path, "orgasegment"), patch_shape=patch_shape, download=True, split=split_choice, + raw_transform=_identity, label_transform=_get_label_transform(), label_dtype=label_dtype, sampler=sampler, ), ] - if split_choice == "train": - _datasets += get_ctc_datasets( - input_path, patch_shape, sampler, raw_transform=to_8bit, label_transform=get_label_transform() - ) + + # Add LIVECell dataset: cell segmentation for phase contrast microscopy images. + _datasets.extend(get_livecell_datasets()) + + # Add EmbedSeg datasets: cell and nuclei segmentation for fluorescence microscopy images. + _datasets.extend(get_embedseg_datasets()) + + # Add YeaZ datasets: yeast segmentation for brightfield and phase contrast microscopy images. + _datasets.extend(get_yeaz_dataset()) + + # Add CVZ Fluo datasets: cell and nuclei segmentation for fluorescence microscopy images. + _datasets.extend(get_cvz_dataset("cell")) + _datasets.extend(get_cvz_dataset("dapi")) + + # Add CTC datasets: cell segmentation for various + if split_choice == "train": # NOTE: We use CTC only for training. + _datasets.extend(get_ctc_datasets()) generalist_dataset = ConcatDataset(*_datasets) - # increasing the sampling attempts for the neurips cellseg dataset + # Increasing the sampling attempts for the NeurIPS CellSeg dataset. generalist_dataset.datasets[3].max_sampling_attempts = 5000 return generalist_dataset def get_generalist_lm_loaders(input_path, patch_shape): - """This returns the concatenated light microscopy datasets implemented in torch_em: - https://github.com/constantinpape/torch-em/tree/main/torch_em/data/datasets - It will automatically download all the datasets - - expect NeurIPS CellSeg (Multi-Modal Microscopy Images) (https://neurips22-cellseg.grand-challenge.org/) - - NOTE: to remove / replace the datasets with another dataset, you need to add the datasets (for train and val splits) - in `get_concat_lm_dataset`. The labels have to be in a label mask instance segmentation format. + """This returns the concatenated light microscopy datasets implemented in `torch_em`: + https://github.com/constantinpape/torch-em/tree/main/torch_em/data/datasets/light_microscopy. + It will automatically download all the datasets, except: + - TissueNet (see `torch_em/data/datasets/light_microscopy/tissuenet.py` for details) + - DynamicNuclearNet (see `torch_em/data/datasets/light_microscopy/dynamicnuclearnet.py` for details) + - CellPose (see `torch_em/data/datasets/light_microscopy/cellpose.py` for details) + - YeaZ (see `torch_em/data/datasets/light_microscopy/yeaz.py` for details) + + NOTE: To remove / replace the datasets with another dataset, you need to add the datasets (for train and val splits) + in `get_concat_lm_dataset`. The labels have to be in a label mask instance segmentation format, i.e. the tensors (inputs & masks) should be of same spatial shape, with each object in the mask having it's own ID. IMPORTANT: the ID 0 is reserved for background, and the IDs must be consecutive. """ + # Get the datasets. generalist_train_dataset = get_concat_lm_datasets(input_path, patch_shape, "train") generalist_val_dataset = get_concat_lm_datasets(input_path, patch_shape, "val") + + # Get the dataloaders. train_loader = torch_em.get_data_loader(generalist_train_dataset, batch_size=2, shuffle=True, num_workers=16) val_loader = torch_em.get_data_loader(generalist_val_dataset, batch_size=1, shuffle=True, num_workers=16) + return train_loader, val_loader diff --git a/finetuning/generalists/training/light_microscopy/train_lm_generalist.py b/finetuning/generalists/training/light_microscopy/train_lm_generalist.py index f07a4f3a..e53d4204 100644 --- a/finetuning/generalists/training/light_microscopy/train_lm_generalist.py +++ b/finetuning/generalists/training/light_microscopy/train_lm_generalist.py @@ -3,9 +3,6 @@ import torch -from torch_em.model import UNETR -from torch_em.loss import DiceBasedDistanceLoss - import micro_sam.training as sam_training from micro_sam.util import export_custom_sam_model @@ -13,7 +10,7 @@ def finetune_lm_generalist(args): - """Code for finetuning SAM on multiple light microscopy datasets""" + """Code for finetuning SAM on multiple Light Microscopy datasets.""" # override this (below) if you have some more complex set-up and need to specify the exact gpu device = "cuda" if torch.cuda.is_available() else "cpu" @@ -22,92 +19,53 @@ def finetune_lm_generalist(args): checkpoint_path = None # override this to start training from a custom checkpoint patch_shape = (512, 512) # the patch shape for training n_objects_per_batch = args.n_objects # this is the number of objects per batch that will be sampled (default: 25) - freeze_parts = None # override this to freeze one or more of these backbones - - # get the trainable segment anything model - model = sam_training.get_trainable_sam_model( - model_type=model_type, - device=device, - checkpoint_path=checkpoint_path, - freeze=freeze_parts - ) - model.to(device) - - # let's get the UNETR model for automatic instance segmentation pipeline - unetr = UNETR( - backbone="sam", - encoder=model.sam.image_encoder, - out_channels=3, - use_sam_stats=True, - final_activation="Sigmoid", - use_skip_connection=False, - resize_input=True, - use_conv_transpose=True - ) - unetr.to(device) - - # let's get the parameters for SAM and the decoder from UNETR - joint_model_params = [params for params in model.parameters()] # sam parameters - for name, params in unetr.named_parameters(): # unetr's decoder parameters - if not name.startswith("encoder"): - joint_model_params.append(params) + checkpoint_name = f"{args.model_type}/lm_generalist_sam" # all the stuff we need for training - optimizer = torch.optim.Adam(joint_model_params, lr=1e-5) - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10, verbose=True) train_loader, val_loader = get_generalist_lm_loaders(input_path=args.input_path, patch_shape=patch_shape) + scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 5, "verbose": True} - # this class creates all the training data for a batch (inputs, prompts and labels) - convert_inputs = sam_training.ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025) - - checkpoint_name = f"{args.model_type}/lm_generalist_sam" - - # the trainer which performs the joint training and validation (implemented using "torch_em") - trainer = sam_training.JointSamTrainer( + # Run training. + sam_training.train_sam( name=checkpoint_name, - save_root=args.save_root, + model_type=model_type, train_loader=train_loader, val_loader=val_loader, - model=model, - optimizer=optimizer, - device=device, - lr_scheduler=scheduler, - logger=sam_training.JointSamLogger, - log_image_interval=100, - mixed_precision=True, - convert_inputs=convert_inputs, + early_stopping=None, # NOTE: Avoid early stopping for training the generalist model. n_objects_per_batch=n_objects_per_batch, - n_sub_iteration=8, - compile_model=False, - mask_prob=0.5, # (optional) overwrite to provide the probability of using mask inputs while training - unetr=unetr, - instance_loss=DiceBasedDistanceLoss(mask_distances_in_bg=True), - instance_metric=DiceBasedDistanceLoss(mask_distances_in_bg=True) + checkpoint_path=checkpoint_path, + with_segmentation_decoder=True, + device=device, + lr=1e-5, + n_iterations=args.iterations, + save_root=args.save_root, + scheduler_kwargs=scheduler_kwargs, + verify_n_labels_in_loader=None, # NOTE: Verifies all labels in the loader(s). + box_distortion_factor=0.05, + load_weights=(not args.from_scratch), ) - trainer.fit(args.iterations, save_every_kth_epoch=args.save_every_kth_epoch) + if args.export_path is not None: checkpoint_path = os.path.join( "" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt" ) export_custom_sam_model( - checkpoint_path=checkpoint_path, - model_type=model_type, - save_path=args.export_path, + checkpoint_path=checkpoint_path, model_type=model_type, save_path=args.export_path ) def main(): parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LM datasets.") parser.add_argument( - "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/", + "--input_path", "-i", default="/mnt/vast-nhr/projects/cidas/cca/experiments/micro_sam/data", help="The filepath to all the respective LM datasets. If the data does not exist yet it will be downloaded" ) parser.add_argument( "--model_type", "-m", default="vit_b", - help="The model type to use for fine-tuning. Either vit_t, vit_b, vit_l or vit_h." + help="The model type to use for fine-tuning. Either 'vit_t', 'vit_b', 'vit_l' or 'vit_h'." ) parser.add_argument( - "--save_root", "-s", + "--save_root", "-s", default=None, help="Where to save the checkpoint and logs. By default they will be saved where this script is run from." ) parser.add_argument( @@ -119,11 +77,10 @@ def main(): help="Where to export the finetuned model to. The exported model can be used in the annotation tools." ) parser.add_argument( - "--save_every_kth_epoch", type=int, default=None, - help="To save every kth epoch while fine-tuning. Expects an integer value." + "--n_objects", type=int, default=25, help="The number of instances (objects) per batch used for finetuning." ) parser.add_argument( - "--n_objects", type=int, default=25, help="The number of instances (objects) per batch used for finetuning." + "--from_scratch", action="store_true", help="Whether to train Segment Anything model from scratch." ) args = parser.parse_args() finetune_lm_generalist(args) diff --git a/finetuning/run_all_finetuning.py b/finetuning/run_all_finetuning.py index 7562e374..e3d3d3f8 100644 --- a/finetuning/run_all_finetuning.py +++ b/finetuning/run_all_finetuning.py @@ -5,43 +5,37 @@ N_OBJECTS = { - "vit_t": 50, + # "vit_t": 50, "vit_b": 40, "vit_l": 30, - "vit_h": 25 + # "vit_h": 25, } -def write_batch_script(out_path, _name, env_name, model_type, save_root): +def write_batch_script(out_path, _name, env_name, model_type, save_root, dry): "Writing scripts with different micro-sam finetunings." batch_script = f"""#!/bin/bash #SBATCH -t 14-00:00:00 -#SBATCH --mem 64G #SBATCH --nodes=1 #SBATCH --ntasks=1 -#SBATCH -p grete:shared -#SBATCH -G A100:1 -#SBATCH -A nim00007 +#SBATCH -p grete-h100:shared +#SBATCH -G H100:1 +#SBATCH -A gzz0001 #SBATCH -c 16 +#SBATCH --mem 64G #SBATCH --qos=14d -#SBATCH --constraint=80gb #SBATCH --job-name={os.path.split(_name)[-1]} -source activate {env_name} \n""" +source ~/.bashrc +micromamba activate {env_name} \n""" - # python script + # The python script python_script = f"python {_name}.py " + python_script += f"-s {save_root} " # The save root folder + python_script += f"-m {model_type} " # The name of the model configuration + python_script += f"--n_objects {N_OBJECTS[model_type[:5]]} " # The choice of the number of objects - # save root folder - python_script += f"-s {save_root} " - - # name of the model configuration - python_script += f"-m {model_type} " - - # choice of the number of objects - python_script += f"--n_objects {N_OBJECTS[model_type[:5]]} " - - # let's add the python script to the bash script + # Add the python script to the bash script batch_script += python_script _op = out_path[:-3] + f"_{os.path.split(_name)[-1]}.sh" @@ -49,7 +43,8 @@ def write_batch_script(out_path, _name, env_name, model_type, save_root): f.write(batch_script) cmd = ["sbatch", _op] - subprocess.run(cmd) + if not dry: + subprocess.run(cmd) def get_batch_script_names(tmp_folder): @@ -101,17 +96,17 @@ def submit_slurm(args): write_batch_script( out_path=get_batch_script_names(tmp_folder), _name=script_name, - env_name="mobilesam" if model_type == "vit_t" else "sam", + env_name="mobilesam" if model_type == "vit_t" else "super", model_type=model_type, - save_root=args.save_root + save_root=args.save_root, + dry=args.dry, ) def main(args): - try: + tmp_dir = "./gpu_jobs" + if os.path.exists(tmp_dir): shutil.rmtree("./gpu_jobs") - except FileNotFoundError: - pass submit_slurm(args) @@ -119,8 +114,19 @@ def main(args): if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() - parser.add_argument("-e", "--experiment_name", type=str, default=None) - parser.add_argument("-s", "--save_root", type=str, default="/scratch/usr/nimanwai/micro-sam/") - parser.add_argument("-m", "--model_type", type=str, default=None) + parser.add_argument( + "-e", "--experiment_name", type=str, default=None, help="The choice of experiment name.", + ) + parser.add_argument( + "-s", "--save_root", type=str, default="/mnt/vast-nhr/projects/cidas/cca/experiments/micro_sam", + help="The path where to store the model checkpoints and logs.", + ) + parser.add_argument( + "-m", "--model_type", type=str, default=None, help="The choice of model type for Segment Anything model." + ) + parser.add_argument( + "--dry", action="store_true", help="Whether to submit the scripts to slurm or only store the scripts." + ) args = parser.parse_args() + main(args) diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py index d72295c6..2fbb4c3d 100644 --- a/micro_sam/models/peft_sam.py +++ b/micro_sam/models/peft_sam.py @@ -114,8 +114,13 @@ def forward(self, x): new_v = self.FacTv(new_v) # NOTE : Scaling Factor is set to 1 as it can be tuned via the learning rate. - qkv[:, :, :, : self.dim] += new_q - qkv[:, :, :, -self.dim:] += new_v + qkv = torch.cat( + [ + qkv[:, :, :, :self.dim] + new_q, # replacing new q values + qkv[:, :, :, self.dim:-self.dim], # leaving the middle part as identical + qkv[:, :, :, -self.dim:] + new_v # replacing new v values + ], dim=-1 + ) return qkv diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py index c576e9df..1edef97d 100644 --- a/micro_sam/training/training.py +++ b/micro_sam/training/training.py @@ -188,6 +188,8 @@ def train_sam( peft_kwargs: Optional[Dict] = None, ignore_warnings: bool = True, verify_n_labels_in_loader: Optional[int] = 50, + box_distortion_factor: Optional[float] = 0.025, + load_weights: bool = True, **model_kwargs, ) -> None: """Run training for a SAM model. @@ -225,6 +227,8 @@ def train_sam( ignore_warnings: Whether to ignore raised warnings. verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders. By default, 50 batches of labels are verified from the dataloaders. + box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks. + load_weights: Whether to initialize the model with pretrained parameter weights. model_kwargs: Additional keyword arguments for the `util.get_sam_model`. """ with _filter_warnings(ignore_warnings): @@ -243,11 +247,12 @@ def train_sam( checkpoint_path=checkpoint_path, return_state=True, peft_kwargs=peft_kwargs, + load_weights=load_weights, **model_kwargs ) # This class creates all the training data for a batch (inputs, prompts and labels). - convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025) + convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=box_distortion_factor) # Create the UNETR decoder (if train with it) and the optimizer. if with_segmentation_decoder: diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index f29cbb67..1097efa4 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -1,6 +1,6 @@ import os from math import ceil, floor -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Union, Tuple import numpy as np @@ -18,6 +18,7 @@ from torch_em.transform.label import PerObjectDistanceTransform from torch_em.transform.raw import normalize_percentile, normalize +from torch_em.data.datasets.light_microscopy.neurips_cell_seg import to_rgb def identity(x): @@ -46,21 +47,22 @@ def get_trainable_sam_model( return_state: bool = False, peft_kwargs: Optional[Dict] = None, flexible_load_checkpoint: bool = False, + load_weights: bool = True, **model_kwargs ) -> TrainableSAM: """Get the trainable sam model. Args: - model_type: The segment anything model that should be finetuned. - The weights of this model will be used for initialization, unless a - custom weight file is passed via `checkpoint_path`. + model_type: The segment anything model that should be finetuned. The weights of this model + will be used for initialization, unless a custom weight file is passed via `checkpoint_path`. device: The device to use for training. checkpoint_path: Path to a custom checkpoint from which to load the model weights. - freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder - By default nothing is frozen and the full model is updated. + freeze: Specify parts of the model that should be frozen, namely: `image_encoder`, `prompt_encoder` and + `mask_decoder`. By default nothing is frozen and the full model is updated. return_state: Whether to return the full checkpoint state. peft_kwargs: Keyword arguments for the PEFT wrapper class. flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints. + load_weights: Whether to initialize the model with pretrained parameter weights. model_kwargs: Additional keyword arguments for the `util.get_sam_model`. Returns: @@ -75,6 +77,7 @@ def get_trainable_sam_model( return_sam=True, return_state=True, flexible_load_checkpoint=flexible_load_checkpoint, + load_weights=load_weights, **model_kwargs ) @@ -185,11 +188,13 @@ def __call__(self, x, y, n_pos, n_neg, get_boxes=False, n_samples=None): get_points = True # keeping the solution open by checking for deterministic/dynamic choice of point prompts - prompt_generator = PointAndBoxPromptGenerator(n_positive_points=n_pos, - n_negative_points=n_neg, - dilation_strength=self.dilation_strength, - get_box_prompts=get_boxes, - get_point_prompts=get_points) + prompt_generator = PointAndBoxPromptGenerator( + n_positive_points=n_pos, + n_negative_points=n_neg, + dilation_strength=self.dilation_strength, + get_box_prompts=get_boxes, + get_point_prompts=get_points + ) batched_inputs = [] batched_sampled_cell_ids_list = [] @@ -215,6 +220,7 @@ def __call__(self, x, y, n_pos, n_neg, get_boxes=False, n_samples=None): batched_input["boxes"] = self.transform.apply_boxes_torch( box_prompts, original_size=gt.shape[-2:] ) if self.transform is not None else box_prompts + if get_points: batched_input["point_coords"] = self.transform.apply_coords_torch( point_prompts, original_size=gt.shape[-2:] @@ -252,34 +258,43 @@ def normalize_to_8bit(raw): class ResizeRawTrafo: - def __init__(self, desired_shape, do_rescaling=False, padding="constant"): + def __init__( + self, + desired_shape: Tuple[int, ...], + do_rescaling: bool = False, + valid_channels: Optional[Union[int, Tuple[int, ...]]] = None, + padding: str = "constant" + ): self.desired_shape = desired_shape - self.padding = padding self.do_rescaling = do_rescaling + self.valid_channels = valid_channels + self.padding = padding def __call__(self, raw): + raw = to_rgb(raw) # Ensure all images are in 3-channels: triplicate one channel to three channels. + if self.do_rescaling: - raw = normalize_percentile(raw, axis=(1, 2)) - raw = np.mean(raw, axis=0) + raw = normalize_percentile(raw, axis=self.valid_channels) raw = normalize(raw) raw = raw * 255 - tmp_ddim = (self.desired_shape[0] - raw.shape[0], self.desired_shape[1] - raw.shape[1]) - ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2) - raw = np.pad( - raw, - pad_width=((ceil(ddim[0]), floor(ddim[0])), (ceil(ddim[1]), floor(ddim[1]))), - mode=self.padding - ) + # Pad the inputs to the desired shape. + tmp_ddim = [desired - curr for desired, curr in zip(self.desired_shape, raw.shape)] + ddim = [(per_dim / 2) for per_dim in tmp_ddim] + pad_width = [(ceil(d), floor(d)) for d in ddim] + raw = np.pad(raw, pad_width=pad_width, mode=self.padding) + assert raw.shape == self.desired_shape return raw class ResizeLabelTrafo: - def __init__(self, desired_shape, padding="constant", min_size=0): + def __init__( + self, desired_shape: Tuple[int, ...], min_size: int = 0, padding: str = "constant", + ): self.desired_shape = desired_shape - self.padding = padding self.min_size = min_size + self.padding = padding def __call__(self, labels): distance_trafo = PerObjectDistanceTransform( diff --git a/micro_sam/util.py b/micro_sam/util.py index ba12e550..867a47ec 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -277,6 +277,7 @@ def get_sam_model( return_state: bool = False, peft_kwargs: Optional[Dict] = None, flexible_load_checkpoint: bool = False, + load_weights: bool = True, **model_kwargs, ) -> SamPredictor: r"""Get the SegmentAnything Predictor. @@ -301,17 +302,18 @@ def get_sam_model( https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html Args: - model_type: The SegmentAnything model to use. Will use the standard vit_h model by default. + model_type: The Segment Anything model to use. Will use the standard `vit_l` model by default. To get a list of all available model names you can call `get_model_names`. device: The device for the model. If none is given will use GPU if available. checkpoint_path: The path to a file with weights that should be used instead of using the weights corresponding to `model_type`. If given, `model_type` must match the architecture - corresponding to the weight file. E.g. if you use weights for SAM with vit_b encoder + corresponding to the weight file. e.g. if you use weights for SAM with `vit_b` encoder then `model_type` must be given as "vit_b". return_sam: Return the sam model object as well as the predictor. return_state: Return the unpickled checkpoint state. peft_kwargs: Keyword arguments for th PEFT wrapper class. flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints. + load_weights: Whether to initialize the model with pretrained parameter weights. model_kwargs: Additional parameters necessary to initialize the Segment Anything model. Returns: @@ -354,7 +356,7 @@ def get_sam_model( raise ValueError(f"Invalid model_type: {abbreviated_model_type}. Expect one of {_MODEL_TYPES}") if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT: raise RuntimeError( - "mobile_sam is required for the vit-tiny." + "'mobile_sam' is required for the vit-tiny. " "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'" ) @@ -378,10 +380,11 @@ def get_sam_model( sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam # In case the model checkpoints have some issues when it is initialized with different parameters than default. - if flexible_load_checkpoint: - sam = _handle_checkpoint_loading(sam, model_state) - else: - sam.load_state_dict(model_state) + if load_weights: + if flexible_load_checkpoint: + sam = _handle_checkpoint_loading(sam, model_state) + else: + sam.load_state_dict(model_state) sam.to(device=device) @@ -990,7 +993,12 @@ def segmentation_to_one_hot(segmentation: np.ndarray, segmentation_ids: Optional n_ids = int(segmentation.max()) else: - assert segmentation_ids[0] != 0, "No objects were found." + msg = "No foreground objects were found." + if len(segmentation_ids) == 0: # The list should not be completely empty. + raise RuntimeError(msg) + + if 0 in segmentation_ids: # The list should not have 'zero' as a value. + raise RuntimeError(msg) # the segmentation ids have to be sorted segmentation_ids = np.sort(segmentation_ids)