From 8c8b648b5070ffa7c3009e31e33f5f1be3d07e95 Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Mon, 30 Dec 2024 22:44:49 +0100 Subject: [PATCH 01/16] Refactor lm generalist datasets and add new datasets --- .../light_microscopy/obtain_lm_datasets.py | 208 +++++++++++++----- .../light_microscopy/train_lm_generalist.py | 82 ++----- micro_sam/training/training.py | 2 + micro_sam/training/util.py | 18 +- 4 files changed, 182 insertions(+), 128 deletions(-) diff --git a/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py b/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py index 8ac629b4f..e6e26c030 100644 --- a/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py +++ b/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py @@ -1,124 +1,220 @@ import os + import numpy as np +from sklearn.model_selection import train_test_split import torch + import torch_em -import torch_em.data.datasets as datasets +from torch_em.data import datasets +from torch_em.transform.raw import normalize from torch_em.data import MinInstanceSampler, ConcatDataset from torch_em.transform.label import PerObjectDistanceTransform -from torch_em.transform.raw import normalize_percentile, normalize +from torch_em.data.datasets.light_microscopy.neurips_cell_seg import to_rgb -from micro_sam.training import identity from micro_sam.training.util import ResizeRawTrafo, ResizeLabelTrafo -def neurips_raw_trafo(raw): - raw = datasets.neurips_cell_seg.to_rgb(raw) # ensures 3 channels for the neurips data - raw = normalize_percentile(raw) - raw = np.mean(raw, axis=0) - raw = normalize(raw) - raw = raw * 255 +def _to_8bit(raw): + "Ensures three channels for inputs and rescale them to 8 bit." + if raw.ndim == 2: + raw = to_rgb(raw) # Ensure all images are in 3-channels: triplicate one channel to three channels. + else: + if raw.shape[0] != 3: + assert raw.shape[0] == 1, raw.shape + raw = np.concatenate([raw] * 3, axis=0) + + raw = normalize(raw) * 255 return raw -def to_8bit(raw): - raw = normalize(raw) - raw = raw * 255 - return raw +def _identity(x): + "Ensures three channels for inputs and avoids rescaling inputs." + x = to_rgb(x) + return x def get_concat_lm_datasets(input_path, patch_shape, split_choice): assert split_choice in ["train", "val"] label_dtype = torch.float32 - sampler = MinInstanceSampler() + sampler = MinInstanceSampler(min_size=10) - def get_label_transform(min_size=0): + def _get_label_transform(min_size=10): label_transform = PerObjectDistanceTransform( - distances=True, boundary_distances=True, directed_distances=False, - foreground=True, instances=True, min_size=min_size + distances=True, + boundary_distances=True, + directed_distances=False, + foreground=True, + instances=True, + min_size=min_size ) return label_transform - def get_ctc_datasets( - input_path, patch_shape, sampler, raw_transform, label_transform, - ignore_datasets=["Fluo-N2DH-GOWT1", "Fluo-N2DL-HeLa"] - ): + def get_embedseg_datasets(): + "Datasets for cell and nuclei segmentation in fluorescence microscopy images." + names = [ + "Mouse-Organoid-Cells-CBG", + "Mouse-Skull-Nuclei-CBG", + "Platynereis-ISH-Nuclei-CBG", + "Platynereis-Nuclei-CBG", + ] + all_embedseg_datasets = [ + datasets.get_embedseg_dataset( + path=os.path.join(input_path, "embedseg"), name=name, patch_shape=(1, *patch_shape), ndim=2, + download=True, n_samples=500 if split_choice == "train" else 100, sampler=sampler, + raw_transform=_to_8bit, label_transform=_get_label_transform(), label_dtype=label_dtype, + ) for name in names + ] + return all_embedseg_datasets + + def get_yeaz_dataset(): + "Datasets for yeast segmentation in phase contrast and brightfield microscopy images." + names = ["bf", "phc"] + all_yeaz_datasets = [ + datasets.get_yeaz_dataset( + path=os.path.join(input_path, "yeaz"), patch_shape=patch_shape, raw_transform=_to_8bit, + ndim=2, download=True, split=split_choice, choice=name, label_transform=_get_label_transform(), + sampler=sampler, label_dtype=label_dtype, + ) for name in names + ] + return all_yeaz_datasets + + def get_cvz_dataset(stain_choice): + "Datasets for cell and nuclei segmentation in fluorescence microscopy images." + # NOTE: We create random splits for this dataset for training the generalist. + raw_paths, label_paths = datasets.cvz_fluo.get_cvz_fluo_paths( + path=os.path.join(input_path, "cvz"), stain_choice=stain_choice, + ) + train_raw_paths, test_raw_paths, train_label_paths, test_label_paths = train_test_split( + raw_paths, label_paths, test_size=0.2, random_state=42, + ) + ds = torch_em.default_segmentation_dataset( + raw_paths=train_raw_paths if split_choice == "train" else test_raw_paths, raw_key=None, + label_paths=train_label_paths if split_choice == "train" else test_label_paths, + label_key=None, is_seg_dataset=False, patch_shape=patch_shape, sampler=sampler, + raw_transform=_identity, label_transform=_get_label_transform(), label_dtype=label_dtype, + ) + return [ds] + + def get_ctc_datasets(): + "Datasets for cell segmentation in different modalities." all_ctc_datasets = [] for dataset_name in datasets.ctc.CTC_CHECKSUMS["train"].keys(): - if dataset_name in ignore_datasets: + if dataset_name in ["Fluo-N2DH-GOWT1", "Fluo-N2DL-HeLa"]: continue all_ctc_datasets.append( datasets.get_ctc_segmentation_dataset( path=os.path.join(input_path, "ctc"), dataset_name=dataset_name, patch_shape=(1, *patch_shape), - sampler=sampler, raw_transform=raw_transform, label_transform=label_transform, download=True + sampler=sampler, raw_transform=_to_8bit, label_transform=_get_label_transform(), + download=True, label_dtype=label_dtype, ) ) return all_ctc_datasets _datasets = [ + # cell segmentation in phase contrast microscopy images. + datasets.get_livecell_dataset( + path=os.path.join(input_path, "livecell"), split=split_choice, patch_shape=patch_shape, download=True, + sampler=sampler, label_dtype=label_dtype, raw_transform=_identity, label_transform=_get_label_transform(), + ), + # cell segmentation in tissue microscopy images. datasets.get_tissuenet_dataset( path=os.path.join(input_path, "tissuenet"), split=split_choice, download=True, patch_shape=patch_shape, - raw_channel="rgb", label_channel="cell", sampler=sampler, label_dtype=label_dtype, - raw_transform=ResizeRawTrafo(patch_shape, do_rescaling=True), - label_transform=ResizeLabelTrafo(patch_shape, min_size=0), - n_samples=1000 if split_choice == "train" else 100 - ), - datasets.get_livecell_dataset( - path=os.path.join(input_path, "livecell"), split=split_choice, patch_shape=patch_shape, - download=True, label_transform=get_label_transform(), sampler=sampler, - label_dtype=label_dtype, raw_transform=identity + raw_channel="rgb", label_channel="cell", raw_transform=ResizeRawTrafo((3, *patch_shape), do_rescaling=True), + label_transform=ResizeLabelTrafo(patch_shape, min_size=10), sampler=sampler, label_dtype=label_dtype, + n_samples=500 if split_choice == "train" else 100, ), + # bacteria segmentation in label-free microscopy images. datasets.get_deepbacs_dataset( path=os.path.join(input_path, "deepbacs"), split=split_choice, patch_shape=patch_shape, - raw_transform=to_8bit, label_transform=get_label_transform(), label_dtype=label_dtype, - download=True, sampler=MinInstanceSampler(min_num_instances=4) + raw_transform=_to_8bit, label_transform=_get_label_transform(), label_dtype=label_dtype, + download=True, sampler=MinInstanceSampler(min_num_instances=4, min_size=10) ), + # cell segmentation in confocal microscopy images. + datasets.get_plantseg_dataset( + path=os.path.join(input_path, "plantseg"), name="root", n_samples=1000 if split_choice == "train" else 100, + patch_shape=(1, *patch_shape), download=True, ndim=2, raw_transform=ResizeRawTrafo((3, *patch_shape)), + sampler=MinInstanceSampler(min_num_instances=4, min_size=10), split=split_choice, label_dtype=label_dtype, + label_transform=ResizeLabelTrafo(patch_shape, min_size=10), + ), + # cell segmentation in multi-modal microscopy images. datasets.get_neurips_cellseg_supervised_dataset( - root=os.path.join(input_path, "neurips-cell-seg"), split=split_choice, - patch_shape=patch_shape, raw_transform=neurips_raw_trafo, label_transform=get_label_transform(), - label_dtype=label_dtype, sampler=MinInstanceSampler(min_num_instances=3) + root=os.path.join(input_path, "neurips_cellseg"), split=split_choice, label_dtype=label_dtype, + patch_shape=patch_shape, raw_transform=_to_8bit, label_transform=_get_label_transform(), + sampler=MinInstanceSampler(min_num_instances=3, min_size=10), download=True, ), + # nuclei segmentation in fluorescence microscopy images. datasets.get_dsb_dataset( path=os.path.join(input_path, "dsb"), split=split_choice if split_choice == "train" else "test", - patch_shape=patch_shape, label_transform=get_label_transform(), sampler=sampler, - label_dtype=label_dtype, download=True, raw_transform=identity + patch_shape=patch_shape, label_transform=_get_label_transform(), sampler=sampler, + label_dtype=label_dtype, download=True, raw_transform=_identity, ), - datasets.get_plantseg_dataset( - path=os.path.join(input_path, "plantseg"), name="root", sampler=MinInstanceSampler(min_num_instances=10), - patch_shape=(1, *patch_shape), download=True, split=split_choice, ndim=2, label_dtype=label_dtype, - raw_transform=ResizeRawTrafo(patch_shape, do_rescaling=False), - label_transform=ResizeLabelTrafo(patch_shape, min_size=0), - n_samples=1000 if split_choice == "train" else 100 + # nuclei segmentation in fluorescence microscopy images. + datasets.get_dynamicnuclearnet_dataset( + path=os.path.join(input_path, "dynamicnuclearnet"), patch_shape=patch_shape, download=True, sampler=sampler, + split=split_choice, n_samples=500 if split_choice == "train" else 100, label_dtype=label_dtype, + raw_transform=_to_8bit, label_transform=_get_label_transform(), + ), + # cell segmentation in multiple microscopy imaging modalities. + datasets.get_cellpose_dataset( + path=os.path.join(input_path, "cellpose"), patch_shape=patch_shape, choice="cyto", raw_transform=_identity, + download=True, split=split_choice if split_choice == "train" else "test", sampler=sampler, + label_dtype=label_dtype, label_transform=_get_label_transform(), + ), + # bacteria segmentation in phase contrast and fluorescence microscopy images. + # worm segmentation in brightfield microscopy images. + datasets.get_omnipose_dataset( + path=os.path.join(input_path, "omnipose"), patch_shape=patch_shape, download=True, + split=split_choice if split_choice == "train" else "test", sampler=sampler, label_dtype=label_dtype, + raw_transform=_to_8bit, label_transform=_get_label_transform(), + ), + # organoid segmentation in brightfield microscopy images. + datasets.get_orgasegment_dataset( + path=os.path.join(input_path, "orgasegment"), patch_shape=patch_shape, download=True, split=split_choice, + raw_transform=_identity, label_transform=_get_label_transform(), label_dtype=label_dtype, sampler=sampler, ), ] - if split_choice == "train": - _datasets += get_ctc_datasets( - input_path, patch_shape, sampler, raw_transform=to_8bit, label_transform=get_label_transform() - ) + + # Add EmbedSeg datasets: cell and nuclei segmentation for fluorescence microscopy images. + _datasets.extend(get_embedseg_datasets()) + + # Add YeaZ datasets: yeast segmentation for brightfield and phase contrast microscopy images. + _datasets.extend(get_yeaz_dataset()) + + # Add CVZ Fluo datasets: cell and nuclei segmentation for fluorescence microscopy images. + _datasets.extend(get_cvz_dataset("cell")) + _datasets.extend(get_cvz_dataset("dapi")) + + # Add CTC datasets: cell segmentation for various + if split_choice == "train": # NOTE: We use CTC only for training. + _datasets.extend(get_ctc_datasets()) generalist_dataset = ConcatDataset(*_datasets) - # increasing the sampling attempts for the neurips cellseg dataset + # Increasing the sampling attempts for the NeurIPS CellSeg dataset. generalist_dataset.datasets[3].max_sampling_attempts = 5000 return generalist_dataset def get_generalist_lm_loaders(input_path, patch_shape): - """This returns the concatenated light microscopy datasets implemented in torch_em: - https://github.com/constantinpape/torch-em/tree/main/torch_em/data/datasets - It will automatically download all the datasets - - expect NeurIPS CellSeg (Multi-Modal Microscopy Images) (https://neurips22-cellseg.grand-challenge.org/) + """This returns the concatenated light microscopy datasets implemented in `torch_em`: + https://github.com/constantinpape/torch-em/tree/main/torch_em/data/datasets/light_microscopy. + It will automatically download all the datasets. - NOTE: to remove / replace the datasets with another dataset, you need to add the datasets (for train and val splits) - in `get_concat_lm_dataset`. The labels have to be in a label mask instance segmentation format. + NOTE: To remove / replace the datasets with another dataset, you need to add the datasets (for train and val splits) + in `get_concat_lm_dataset`. The labels have to be in a label mask instance segmentation format, i.e. the tensors (inputs & masks) should be of same spatial shape, with each object in the mask having it's own ID. IMPORTANT: the ID 0 is reserved for background, and the IDs must be consecutive. """ + # Get the datasets. generalist_train_dataset = get_concat_lm_datasets(input_path, patch_shape, "train") generalist_val_dataset = get_concat_lm_datasets(input_path, patch_shape, "val") + + # Get the dataloaders. train_loader = torch_em.get_data_loader(generalist_train_dataset, batch_size=2, shuffle=True, num_workers=16) val_loader = torch_em.get_data_loader(generalist_val_dataset, batch_size=1, shuffle=True, num_workers=16) + return train_loader, val_loader diff --git a/finetuning/generalists/training/light_microscopy/train_lm_generalist.py b/finetuning/generalists/training/light_microscopy/train_lm_generalist.py index f07a4f3ab..32b76037c 100644 --- a/finetuning/generalists/training/light_microscopy/train_lm_generalist.py +++ b/finetuning/generalists/training/light_microscopy/train_lm_generalist.py @@ -3,9 +3,6 @@ import torch -from torch_em.model import UNETR -from torch_em.loss import DiceBasedDistanceLoss - import micro_sam.training as sam_training from micro_sam.util import export_custom_sam_model @@ -13,7 +10,7 @@ def finetune_lm_generalist(args): - """Code for finetuning SAM on multiple light microscopy datasets""" + """Code for finetuning SAM on multiple Light Microscopy datasets.""" # override this (below) if you have some more complex set-up and need to specify the exact gpu device = "cuda" if torch.cuda.is_available() else "cpu" @@ -22,69 +19,30 @@ def finetune_lm_generalist(args): checkpoint_path = None # override this to start training from a custom checkpoint patch_shape = (512, 512) # the patch shape for training n_objects_per_batch = args.n_objects # this is the number of objects per batch that will be sampled (default: 25) - freeze_parts = None # override this to freeze one or more of these backbones - - # get the trainable segment anything model - model = sam_training.get_trainable_sam_model( - model_type=model_type, - device=device, - checkpoint_path=checkpoint_path, - freeze=freeze_parts - ) - model.to(device) - - # let's get the UNETR model for automatic instance segmentation pipeline - unetr = UNETR( - backbone="sam", - encoder=model.sam.image_encoder, - out_channels=3, - use_sam_stats=True, - final_activation="Sigmoid", - use_skip_connection=False, - resize_input=True, - use_conv_transpose=True - ) - unetr.to(device) - - # let's get the parameters for SAM and the decoder from UNETR - joint_model_params = [params for params in model.parameters()] # sam parameters - for name, params in unetr.named_parameters(): # unetr's decoder parameters - if not name.startswith("encoder"): - joint_model_params.append(params) + checkpoint_name = f"{args.model_type}/lm_generalist_sam" # all the stuff we need for training - optimizer = torch.optim.Adam(joint_model_params, lr=1e-5) - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10, verbose=True) train_loader, val_loader = get_generalist_lm_loaders(input_path=args.input_path, patch_shape=patch_shape) + scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 10, "verbose": True} - # this class creates all the training data for a batch (inputs, prompts and labels) - convert_inputs = sam_training.ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025) - - checkpoint_name = f"{args.model_type}/lm_generalist_sam" - - # the trainer which performs the joint training and validation (implemented using "torch_em") - trainer = sam_training.JointSamTrainer( + # Run training. + sam_training.train_sam( name=checkpoint_name, - save_root=args.save_root, + model_type=model_type, train_loader=train_loader, val_loader=val_loader, - model=model, - optimizer=optimizer, - device=device, - lr_scheduler=scheduler, - logger=sam_training.JointSamLogger, - log_image_interval=100, - mixed_precision=True, - convert_inputs=convert_inputs, + early_stopping=None, # NOTE: Avoid early stopping for training the generalist model. n_objects_per_batch=n_objects_per_batch, - n_sub_iteration=8, - compile_model=False, - mask_prob=0.5, # (optional) overwrite to provide the probability of using mask inputs while training - unetr=unetr, - instance_loss=DiceBasedDistanceLoss(mask_distances_in_bg=True), - instance_metric=DiceBasedDistanceLoss(mask_distances_in_bg=True) + checkpoint_path=checkpoint_path, + with_segmentation_decoder=True, + device=device, + lr=1e-5, + n_iterations=args.iterations, + save_root=args.save_root, + scheduler_kwargs=scheduler_kwargs, + verify_n_labels_in_loader=None, # NOTE: Verifies all labels in the loader(s). ) - trainer.fit(args.iterations, save_every_kth_epoch=args.save_every_kth_epoch) + if args.export_path is not None: checkpoint_path = os.path.join( "" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt" @@ -99,7 +57,7 @@ def finetune_lm_generalist(args): def main(): parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LM datasets.") parser.add_argument( - "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/", + "--input_path", "-i", default="/mnt/vast-nhr/projects/cidas/cca/experiments/micro_sam/data", help="The filepath to all the respective LM datasets. If the data does not exist yet it will be downloaded" ) parser.add_argument( @@ -107,7 +65,7 @@ def main(): help="The model type to use for fine-tuning. Either vit_t, vit_b, vit_l or vit_h." ) parser.add_argument( - "--save_root", "-s", + "--save_root", "-s", default=None, help="Where to save the checkpoint and logs. By default they will be saved where this script is run from." ) parser.add_argument( @@ -118,10 +76,6 @@ def main(): "--export_path", "-e", help="Where to export the finetuned model to. The exported model can be used in the annotation tools." ) - parser.add_argument( - "--save_every_kth_epoch", type=int, default=None, - help="To save every kth epoch while fine-tuning. Expects an integer value." - ) parser.add_argument( "--n_objects", type=int, default=25, help="The number of instances (objects) per batch used for finetuning." ) diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py index 79485b8d0..c0826757a 100644 --- a/micro_sam/training/training.py +++ b/micro_sam/training/training.py @@ -238,6 +238,8 @@ def train_sam( _check_loader(train_loader, with_segmentation_decoder, "train", verify_n_labels_in_loader) _check_loader(val_loader, with_segmentation_decoder, "val", verify_n_labels_in_loader) + breakpoint() + device = get_device(device) # Get the trainable segment anything model. model, state = get_trainable_sam_model( diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index f29cbb670..72ac9207d 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -18,6 +18,7 @@ from torch_em.transform.label import PerObjectDistanceTransform from torch_em.transform.raw import normalize_percentile, normalize +from torch_em.data.datasets.light_microscopy.neurips_cell_seg import to_rgb def identity(x): @@ -258,19 +259,20 @@ def __init__(self, desired_shape, do_rescaling=False, padding="constant"): self.do_rescaling = do_rescaling def __call__(self, raw): + raw = to_rgb(raw) # Ensure all images are in 3-channels: triplicate one channel to three channels. + if self.do_rescaling: + # NOTE: Below is done for TissueNet: to work with the valid channels. raw = normalize_percentile(raw, axis=(1, 2)) - raw = np.mean(raw, axis=0) raw = normalize(raw) raw = raw * 255 - tmp_ddim = (self.desired_shape[0] - raw.shape[0], self.desired_shape[1] - raw.shape[1]) - ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2) - raw = np.pad( - raw, - pad_width=((ceil(ddim[0]), floor(ddim[0])), (ceil(ddim[1]), floor(ddim[1]))), - mode=self.padding - ) + # Pad the inputs to the desired shape. + tmp_ddim = [desired - curr for desired, curr in zip(self.desired_shape, raw.shape)] + ddim = [(per_dim / 2) for per_dim in tmp_ddim] + pad_width = [(ceil(d), floor(d)) for d in ddim] + raw = np.pad(raw, pad_width=pad_width, mode=self.padding) + assert raw.shape == self.desired_shape return raw From 379af0196d20f128d7e2c1c30d820dc2cac2da92 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Tue, 31 Dec 2024 00:19:49 +0100 Subject: [PATCH 02/16] Modify .gitignore file --- .gitignore | 6 ++++++ finetuning/.gitignore | 6 ------ 2 files changed, 6 insertions(+), 6 deletions(-) delete mode 100644 finetuning/.gitignore diff --git a/.gitignore b/.gitignore index 734e5772e..cbb13274d 100644 --- a/.gitignore +++ b/.gitignore @@ -174,9 +174,15 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ +# Additional stuff to avoid tracking. # Torch-em stuff checkpoints/ logs/ # "gpu_jobs" folder where slurm batch submission scripts are saved gpu_jobs/ +sam_embeddings/ +results/ +iterative_prompting_results/ +*.sh +*.png diff --git a/finetuning/.gitignore b/finetuning/.gitignore deleted file mode 100644 index b078b91d0..000000000 --- a/finetuning/.gitignore +++ /dev/null @@ -1,6 +0,0 @@ -checkpoints/ -logs/ -sam_embeddings/ -results/ -iterative_prompting_results/ -*.sh From fa764d5c0f985f84daaf9505a9712dd12b89cb20 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Tue, 31 Dec 2024 11:49:28 +0100 Subject: [PATCH 03/16] Document datasets without auto download support --- .../training/light_microscopy/obtain_lm_datasets.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py b/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py index e6e26c030..935b63a5d 100644 --- a/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py +++ b/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py @@ -202,7 +202,11 @@ def get_ctc_datasets(): def get_generalist_lm_loaders(input_path, patch_shape): """This returns the concatenated light microscopy datasets implemented in `torch_em`: https://github.com/constantinpape/torch-em/tree/main/torch_em/data/datasets/light_microscopy. - It will automatically download all the datasets. + It will automatically download all the datasets, except: + - TissueNet (see `torch_em/data/datasets/light_microscopy/tissuenet.py` for details) + - DynamicNuclearNet (see `torch_em/data/datasets/light_microscopy/dynamicnuclearnet.py` for details) + - CellPose (see `torch_em/data/datasets/light_microscopy/cellpose.py` for details) + - YeaZ (see `torch_em/data/datasets/light_microscopy/yeaz.py` for details) NOTE: To remove / replace the datasets with another dataset, you need to add the datasets (for train and val splits) in `get_concat_lm_dataset`. The labels have to be in a label mask instance segmentation format, From e19f4233f473473b5aa4e0fb51c2ea9bff37b4e5 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Tue, 31 Dec 2024 13:24:26 +0100 Subject: [PATCH 04/16] Remove breakpoint --- micro_sam/training/training.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py index c0826757a..79485b8d0 100644 --- a/micro_sam/training/training.py +++ b/micro_sam/training/training.py @@ -238,8 +238,6 @@ def train_sam( _check_loader(train_loader, with_segmentation_decoder, "train", verify_n_labels_in_loader) _check_loader(val_loader, with_segmentation_decoder, "val", verify_n_labels_in_loader) - breakpoint() - device = get_device(device) # Get the trainable segment anything model. model, state = get_trainable_sam_model( From 31bc3983361bfefbbcaa7954d28288a64db8a31d Mon Sep 17 00:00:00 2001 From: anwai98 Date: Tue, 31 Dec 2024 15:16:15 +0100 Subject: [PATCH 05/16] Fix tissuenet normalization --- micro_sam/training/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index 72ac9207d..35a767c62 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -263,7 +263,7 @@ def __call__(self, raw): if self.do_rescaling: # NOTE: Below is done for TissueNet: to work with the valid channels. - raw = normalize_percentile(raw, axis=(1, 2)) + raw = normalize_percentile(raw, axis=(0, 1)) raw = normalize(raw) raw = raw * 255 From adc2129a353f118d37a11b608b64203997d9f752 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Tue, 31 Dec 2024 18:37:20 +0100 Subject: [PATCH 06/16] Refactor livecell dataset --- .../light_microscopy/obtain_lm_datasets.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py b/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py index 935b63a5d..0a91d39e6 100644 --- a/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py +++ b/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py @@ -51,6 +51,17 @@ def _get_label_transform(min_size=10): ) return label_transform + def get_livecell_datasets(): + "Datasets for cell segmentation in phase contrast microscopy images." + all_livecell_datasets = [ + datasets.get_livecell_dataset( + path=os.path.join(input_path, "livecell"), split=split_choice, patch_shape=patch_shape, + sampler=sampler, label_dtype=label_dtype, raw_transform=_identity, download=True, cell_types=[ctype], + label_transform=_get_label_transform(), n_samples=200 if split_choice == "train" else None, + ) for ctype in datasets.livecell.CELL_TYPES + ] + return all_livecell_datasets + def get_embedseg_datasets(): "Datasets for cell and nuclei segmentation in fluorescence microscopy images." names = [ @@ -114,11 +125,6 @@ def get_ctc_datasets(): return all_ctc_datasets _datasets = [ - # cell segmentation in phase contrast microscopy images. - datasets.get_livecell_dataset( - path=os.path.join(input_path, "livecell"), split=split_choice, patch_shape=patch_shape, download=True, - sampler=sampler, label_dtype=label_dtype, raw_transform=_identity, label_transform=_get_label_transform(), - ), # cell segmentation in tissue microscopy images. datasets.get_tissuenet_dataset( path=os.path.join(input_path, "tissuenet"), split=split_choice, download=True, patch_shape=patch_shape, @@ -177,6 +183,9 @@ def get_ctc_datasets(): ), ] + # Add LIVECell dataset: cell segmentation for phase contrast microscopy images. + _datasets.extend(get_livecell_datasets()) + # Add EmbedSeg datasets: cell and nuclei segmentation for fluorescence microscopy images. _datasets.extend(get_embedseg_datasets()) From c9a6c8fee296e10e9f98747120daa90e0245cd28 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Tue, 31 Dec 2024 18:45:08 +0100 Subject: [PATCH 07/16] Reduce plantseg and livecell-val further --- .../training/light_microscopy/obtain_lm_datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py b/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py index 0a91d39e6..0bb8e1692 100644 --- a/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py +++ b/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py @@ -57,7 +57,7 @@ def get_livecell_datasets(): datasets.get_livecell_dataset( path=os.path.join(input_path, "livecell"), split=split_choice, patch_shape=patch_shape, sampler=sampler, label_dtype=label_dtype, raw_transform=_identity, download=True, cell_types=[ctype], - label_transform=_get_label_transform(), n_samples=200 if split_choice == "train" else None, + label_transform=_get_label_transform(), n_samples=200 if split_choice == "train" else 50, ) for ctype in datasets.livecell.CELL_TYPES ] return all_livecell_datasets @@ -140,7 +140,7 @@ def get_ctc_datasets(): ), # cell segmentation in confocal microscopy images. datasets.get_plantseg_dataset( - path=os.path.join(input_path, "plantseg"), name="root", n_samples=1000 if split_choice == "train" else 100, + path=os.path.join(input_path, "plantseg"), name="root", n_samples=500 if split_choice == "train" else 100, patch_shape=(1, *patch_shape), download=True, ndim=2, raw_transform=ResizeRawTrafo((3, *patch_shape)), sampler=MinInstanceSampler(min_num_instances=4, min_size=10), split=split_choice, label_dtype=label_dtype, label_transform=ResizeLabelTrafo(patch_shape, min_size=10), From f547a5d6d4e238cb73144e45499b62af4dbcccc5 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Tue, 31 Dec 2024 20:10:16 +0100 Subject: [PATCH 08/16] Add debugging scripts --- .../light_microscopy/obtain_lm_datasets.py | 114 ++++++++++-------- .../light_microscopy/train_lm_generalist.py | 2 +- 2 files changed, 62 insertions(+), 54 deletions(-) diff --git a/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py b/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py index 0bb8e1692..e51cbcf57 100644 --- a/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py +++ b/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py @@ -132,78 +132,81 @@ def get_ctc_datasets(): label_transform=ResizeLabelTrafo(patch_shape, min_size=10), sampler=sampler, label_dtype=label_dtype, n_samples=500 if split_choice == "train" else 100, ), - # bacteria segmentation in label-free microscopy images. - datasets.get_deepbacs_dataset( - path=os.path.join(input_path, "deepbacs"), split=split_choice, patch_shape=patch_shape, - raw_transform=_to_8bit, label_transform=_get_label_transform(), label_dtype=label_dtype, - download=True, sampler=MinInstanceSampler(min_num_instances=4, min_size=10) - ), - # cell segmentation in confocal microscopy images. - datasets.get_plantseg_dataset( - path=os.path.join(input_path, "plantseg"), name="root", n_samples=500 if split_choice == "train" else 100, - patch_shape=(1, *patch_shape), download=True, ndim=2, raw_transform=ResizeRawTrafo((3, *patch_shape)), - sampler=MinInstanceSampler(min_num_instances=4, min_size=10), split=split_choice, label_dtype=label_dtype, - label_transform=ResizeLabelTrafo(patch_shape, min_size=10), - ), - # cell segmentation in multi-modal microscopy images. - datasets.get_neurips_cellseg_supervised_dataset( - root=os.path.join(input_path, "neurips_cellseg"), split=split_choice, label_dtype=label_dtype, - patch_shape=patch_shape, raw_transform=_to_8bit, label_transform=_get_label_transform(), - sampler=MinInstanceSampler(min_num_instances=3, min_size=10), download=True, - ), - # nuclei segmentation in fluorescence microscopy images. - datasets.get_dsb_dataset( - path=os.path.join(input_path, "dsb"), split=split_choice if split_choice == "train" else "test", - patch_shape=patch_shape, label_transform=_get_label_transform(), sampler=sampler, - label_dtype=label_dtype, download=True, raw_transform=_identity, - ), - # nuclei segmentation in fluorescence microscopy images. - datasets.get_dynamicnuclearnet_dataset( - path=os.path.join(input_path, "dynamicnuclearnet"), patch_shape=patch_shape, download=True, sampler=sampler, - split=split_choice, n_samples=500 if split_choice == "train" else 100, label_dtype=label_dtype, - raw_transform=_to_8bit, label_transform=_get_label_transform(), - ), + # # bacteria segmentation in label-free microscopy images. + # datasets.get_deepbacs_dataset( + # path=os.path.join(input_path, "deepbacs"), split=split_choice, patch_shape=patch_shape, + # raw_transform=_to_8bit, label_transform=_get_label_transform(), label_dtype=label_dtype, + # download=True, sampler=MinInstanceSampler(min_num_instances=4, min_size=10) + # ), + # # cell segmentation in confocal microscopy images. + # datasets.get_plantseg_dataset( + # path=os.path.join(input_path, "plantseg"), name="root", n_samples=500 if split_choice == "train" else 100, + # patch_shape=(1, *patch_shape), download=True, ndim=2, raw_transform=ResizeRawTrafo((3, *patch_shape)), + # sampler=MinInstanceSampler(min_num_instances=4, min_size=10), split=split_choice, label_dtype=label_dtype, + # label_transform=ResizeLabelTrafo(patch_shape, min_size=10), + # ), + # # cell segmentation in multi-modal microscopy images. + # datasets.get_neurips_cellseg_supervised_dataset( + # root=os.path.join(input_path, "neurips_cellseg"), split=split_choice, label_dtype=label_dtype, + # patch_shape=patch_shape, raw_transform=_to_8bit, label_transform=_get_label_transform(), + # sampler=MinInstanceSampler(min_num_instances=3, min_size=10), download=True, + # ), + # # nuclei segmentation in fluorescence microscopy images. + # datasets.get_dsb_dataset( + # path=os.path.join(input_path, "dsb"), split=split_choice if split_choice == "train" else "test", + # patch_shape=patch_shape, label_transform=_get_label_transform(), sampler=sampler, + # label_dtype=label_dtype, download=True, raw_transform=_identity, + # ), + # # nuclei segmentation in fluorescence microscopy images. + # datasets.get_dynamicnuclearnet_dataset( + # path=os.path.join(input_path, "dynamicnuclearnet"), patch_shape=patch_shape, download=True, sampler=sampler, + # split=split_choice, n_samples=500 if split_choice == "train" else 100, label_dtype=label_dtype, + # raw_transform=_to_8bit, label_transform=_get_label_transform(), + # ), # cell segmentation in multiple microscopy imaging modalities. - datasets.get_cellpose_dataset( - path=os.path.join(input_path, "cellpose"), patch_shape=patch_shape, choice="cyto", raw_transform=_identity, - download=True, split=split_choice if split_choice == "train" else "test", sampler=sampler, - label_dtype=label_dtype, label_transform=_get_label_transform(), - ), + # TODO + # datasets.get_cellpose_dataset( + # path=os.path.join(input_path, "cellpose"), patch_shape=patch_shape, choice="cyto", raw_transform=_identity, + # download=True, split=split_choice if split_choice == "train" else "test", sampler=sampler, + # label_dtype=label_dtype, label_transform=_get_label_transform(), + # ), # bacteria segmentation in phase contrast and fluorescence microscopy images. # worm segmentation in brightfield microscopy images. - datasets.get_omnipose_dataset( - path=os.path.join(input_path, "omnipose"), patch_shape=patch_shape, download=True, - split=split_choice if split_choice == "train" else "test", sampler=sampler, label_dtype=label_dtype, - raw_transform=_to_8bit, label_transform=_get_label_transform(), - ), + # datasets.get_omnipose_dataset( + # path=os.path.join(input_path, "omnipose"), patch_shape=patch_shape, download=True, + # split=split_choice if split_choice == "train" else "test", sampler=sampler, + # label_dtype=label_dtype, raw_transform=_to_8bit, label_transform=_get_label_transform(), + # ), # organoid segmentation in brightfield microscopy images. - datasets.get_orgasegment_dataset( - path=os.path.join(input_path, "orgasegment"), patch_shape=patch_shape, download=True, split=split_choice, - raw_transform=_identity, label_transform=_get_label_transform(), label_dtype=label_dtype, sampler=sampler, - ), + # TODO + # datasets.get_orgasegment_dataset( + # path=os.path.join(input_path, "orgasegment"), patch_shape=patch_shape, download=True, split=split_choice, + # raw_transform=_identity, label_transform=_get_label_transform(), label_dtype=label_dtype, sampler=sampler, + # ), ] # Add LIVECell dataset: cell segmentation for phase contrast microscopy images. - _datasets.extend(get_livecell_datasets()) + # _datasets.extend(get_livecell_datasets()) # Add EmbedSeg datasets: cell and nuclei segmentation for fluorescence microscopy images. - _datasets.extend(get_embedseg_datasets()) + # _datasets.extend(get_embedseg_datasets()) # Add YeaZ datasets: yeast segmentation for brightfield and phase contrast microscopy images. - _datasets.extend(get_yeaz_dataset()) + # _datasets.extend(get_yeaz_dataset()) # Add CVZ Fluo datasets: cell and nuclei segmentation for fluorescence microscopy images. - _datasets.extend(get_cvz_dataset("cell")) - _datasets.extend(get_cvz_dataset("dapi")) + # TODO + # _datasets.extend(get_cvz_dataset("cell")) + # _datasets.extend(get_cvz_dataset("dapi")) # Add CTC datasets: cell segmentation for various - if split_choice == "train": # NOTE: We use CTC only for training. - _datasets.extend(get_ctc_datasets()) + # if split_choice == "train": # NOTE: We use CTC only for training. + # _datasets.extend(get_ctc_datasets()) generalist_dataset = ConcatDataset(*_datasets) # Increasing the sampling attempts for the NeurIPS CellSeg dataset. - generalist_dataset.datasets[3].max_sampling_attempts = 5000 + # generalist_dataset.datasets[3].max_sampling_attempts = 5000 return generalist_dataset @@ -230,4 +233,9 @@ def get_generalist_lm_loaders(input_path, patch_shape): train_loader = torch_em.get_data_loader(generalist_train_dataset, batch_size=2, shuffle=True, num_workers=16) val_loader = torch_em.get_data_loader(generalist_val_dataset, batch_size=1, shuffle=True, num_workers=16) + from torch_em.util.debug import check_loader + check_loader(train_loader, 16) + + breakpoint() + return train_loader, val_loader diff --git a/finetuning/generalists/training/light_microscopy/train_lm_generalist.py b/finetuning/generalists/training/light_microscopy/train_lm_generalist.py index 32b76037c..854f56fb6 100644 --- a/finetuning/generalists/training/light_microscopy/train_lm_generalist.py +++ b/finetuning/generalists/training/light_microscopy/train_lm_generalist.py @@ -23,7 +23,7 @@ def finetune_lm_generalist(args): # all the stuff we need for training train_loader, val_loader = get_generalist_lm_loaders(input_path=args.input_path, patch_shape=patch_shape) - scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 10, "verbose": True} + scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 5, "verbose": True} # Run training. sam_training.train_sam( From cd2d3b1bc5470ab0a2b7117195647808e66eb278 Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Tue, 31 Dec 2024 22:01:06 +0100 Subject: [PATCH 09/16] Create custom raw trafo for cellpose --- .../light_microscopy/obtain_lm_datasets.py | 143 ++++++++++-------- micro_sam/training/util.py | 2 +- 2 files changed, 80 insertions(+), 65 deletions(-) diff --git a/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py b/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py index e51cbcf57..31f5e513d 100644 --- a/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py +++ b/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py @@ -34,6 +34,28 @@ def _identity(x): return x +def _cellpose_raw_trafo(x): + """Transforms input images to desired format. + NOTE: The input channel logic is arranged a bit strangely in `cyto` dataset. + We take care of it here. + """ + r, g, b = x + + assert g.max() != 0 + if r.max() == 0: + # The image is 1 channel and exists in green channel only. + assert b.max() == 0 + x = np.concatenate([g[None]] * 3, axis=0) + + elif r.max() != 0 and g.max() != 0: + # The image is 2 channels and we sort the channels such that - 0: cell, 1: nucleus + x = np.stack([g, r, np.zeros_like(b)], axis=0) + + x = to_rgb(x) # Ensures three channels for inputs and avoids rescaling inputs. + + return x + + def get_concat_lm_datasets(input_path, patch_shape, split_choice): assert split_choice in ["train", "val"] @@ -72,9 +94,10 @@ def get_embedseg_datasets(): ] all_embedseg_datasets = [ datasets.get_embedseg_dataset( - path=os.path.join(input_path, "embedseg"), name=name, patch_shape=(1, *patch_shape), ndim=2, - download=True, n_samples=500 if split_choice == "train" else 100, sampler=sampler, - raw_transform=_to_8bit, label_transform=_get_label_transform(), label_dtype=label_dtype, + path=os.path.join(input_path, "embedseg"), name=name, patch_shape=(1, *patch_shape), + download=True, n_samples=500 if split_choice == "train" else 100, raw_transform=_to_8bit, + label_dtype=label_dtype, label_transform=_get_label_transform(), ndim=2, + sampler=MinInstanceSampler(min_num_instances=3, min_size=10), ) for name in names ] return all_embedseg_datasets @@ -132,81 +155,78 @@ def get_ctc_datasets(): label_transform=ResizeLabelTrafo(patch_shape, min_size=10), sampler=sampler, label_dtype=label_dtype, n_samples=500 if split_choice == "train" else 100, ), - # # bacteria segmentation in label-free microscopy images. - # datasets.get_deepbacs_dataset( - # path=os.path.join(input_path, "deepbacs"), split=split_choice, patch_shape=patch_shape, - # raw_transform=_to_8bit, label_transform=_get_label_transform(), label_dtype=label_dtype, - # download=True, sampler=MinInstanceSampler(min_num_instances=4, min_size=10) - # ), - # # cell segmentation in confocal microscopy images. - # datasets.get_plantseg_dataset( - # path=os.path.join(input_path, "plantseg"), name="root", n_samples=500 if split_choice == "train" else 100, - # patch_shape=(1, *patch_shape), download=True, ndim=2, raw_transform=ResizeRawTrafo((3, *patch_shape)), - # sampler=MinInstanceSampler(min_num_instances=4, min_size=10), split=split_choice, label_dtype=label_dtype, - # label_transform=ResizeLabelTrafo(patch_shape, min_size=10), - # ), - # # cell segmentation in multi-modal microscopy images. - # datasets.get_neurips_cellseg_supervised_dataset( - # root=os.path.join(input_path, "neurips_cellseg"), split=split_choice, label_dtype=label_dtype, - # patch_shape=patch_shape, raw_transform=_to_8bit, label_transform=_get_label_transform(), - # sampler=MinInstanceSampler(min_num_instances=3, min_size=10), download=True, - # ), - # # nuclei segmentation in fluorescence microscopy images. - # datasets.get_dsb_dataset( - # path=os.path.join(input_path, "dsb"), split=split_choice if split_choice == "train" else "test", - # patch_shape=patch_shape, label_transform=_get_label_transform(), sampler=sampler, - # label_dtype=label_dtype, download=True, raw_transform=_identity, - # ), - # # nuclei segmentation in fluorescence microscopy images. - # datasets.get_dynamicnuclearnet_dataset( - # path=os.path.join(input_path, "dynamicnuclearnet"), patch_shape=patch_shape, download=True, sampler=sampler, - # split=split_choice, n_samples=500 if split_choice == "train" else 100, label_dtype=label_dtype, - # raw_transform=_to_8bit, label_transform=_get_label_transform(), - # ), + # bacteria segmentation in label-free microscopy images. + datasets.get_deepbacs_dataset( + path=os.path.join(input_path, "deepbacs"), split=split_choice, patch_shape=patch_shape, + raw_transform=_to_8bit, label_transform=_get_label_transform(), label_dtype=label_dtype, + download=True, sampler=MinInstanceSampler(min_num_instances=4, min_size=10) + ), + # cell segmentation in confocal microscopy images. + datasets.get_plantseg_dataset( + path=os.path.join(input_path, "plantseg"), name="root", n_samples=500 if split_choice == "train" else 100, + patch_shape=(1, *patch_shape), download=True, ndim=2, raw_transform=ResizeRawTrafo((3, *patch_shape)), + sampler=MinInstanceSampler(min_num_instances=4, min_size=10), split=split_choice, label_dtype=label_dtype, + label_transform=ResizeLabelTrafo(patch_shape, min_size=10), + ), + # cell segmentation in multi-modal microscopy images. + datasets.get_neurips_cellseg_supervised_dataset( + root=os.path.join(input_path, "neurips_cellseg"), split=split_choice, label_dtype=label_dtype, + patch_shape=patch_shape, raw_transform=_to_8bit, label_transform=_get_label_transform(), + sampler=MinInstanceSampler(min_num_instances=3, min_size=10), download=True, + ), + # nuclei segmentation in fluorescence microscopy images. + datasets.get_dsb_dataset( + path=os.path.join(input_path, "dsb"), split=split_choice if split_choice == "train" else "test", + patch_shape=patch_shape, label_transform=_get_label_transform(), sampler=sampler, + label_dtype=label_dtype, download=True, raw_transform=_identity, + ), + # nuclei segmentation in fluorescence microscopy images. + datasets.get_dynamicnuclearnet_dataset( + path=os.path.join(input_path, "dynamicnuclearnet"), patch_shape=patch_shape, download=True, sampler=sampler, + split=split_choice, n_samples=500 if split_choice == "train" else 100, label_dtype=label_dtype, + raw_transform=_to_8bit, label_transform=_get_label_transform(), + ), # cell segmentation in multiple microscopy imaging modalities. - # TODO - # datasets.get_cellpose_dataset( - # path=os.path.join(input_path, "cellpose"), patch_shape=patch_shape, choice="cyto", raw_transform=_identity, - # download=True, split=split_choice if split_choice == "train" else "test", sampler=sampler, - # label_dtype=label_dtype, label_transform=_get_label_transform(), - # ), + datasets.get_cellpose_dataset( + path=os.path.join(input_path, "cellpose"), patch_shape=patch_shape, choice="cyto", sampler=sampler, + download=True, split=split_choice if split_choice == "train" else "test", label_dtype=label_dtype, + label_transform=_get_label_transform(), raw_transform=_cellpose_raw_trafo, + ), # bacteria segmentation in phase contrast and fluorescence microscopy images. # worm segmentation in brightfield microscopy images. - # datasets.get_omnipose_dataset( - # path=os.path.join(input_path, "omnipose"), patch_shape=patch_shape, download=True, - # split=split_choice if split_choice == "train" else "test", sampler=sampler, - # label_dtype=label_dtype, raw_transform=_to_8bit, label_transform=_get_label_transform(), - # ), + datasets.get_omnipose_dataset( + path=os.path.join(input_path, "omnipose"), patch_shape=patch_shape, download=True, + split=split_choice if split_choice == "train" else "test", sampler=sampler, + label_dtype=label_dtype, raw_transform=_to_8bit, label_transform=_get_label_transform(), + ), # organoid segmentation in brightfield microscopy images. - # TODO - # datasets.get_orgasegment_dataset( - # path=os.path.join(input_path, "orgasegment"), patch_shape=patch_shape, download=True, split=split_choice, - # raw_transform=_identity, label_transform=_get_label_transform(), label_dtype=label_dtype, sampler=sampler, - # ), + datasets.get_orgasegment_dataset( + path=os.path.join(input_path, "orgasegment"), patch_shape=patch_shape, download=True, split=split_choice, + raw_transform=_identity, label_transform=_get_label_transform(), label_dtype=label_dtype, sampler=sampler, + ), ] # Add LIVECell dataset: cell segmentation for phase contrast microscopy images. - # _datasets.extend(get_livecell_datasets()) + _datasets.extend(get_livecell_datasets()) # Add EmbedSeg datasets: cell and nuclei segmentation for fluorescence microscopy images. - # _datasets.extend(get_embedseg_datasets()) + _datasets.extend(get_embedseg_datasets()) # Add YeaZ datasets: yeast segmentation for brightfield and phase contrast microscopy images. - # _datasets.extend(get_yeaz_dataset()) + _datasets.extend(get_yeaz_dataset()) # Add CVZ Fluo datasets: cell and nuclei segmentation for fluorescence microscopy images. - # TODO - # _datasets.extend(get_cvz_dataset("cell")) - # _datasets.extend(get_cvz_dataset("dapi")) + _datasets.extend(get_cvz_dataset("cell")) + _datasets.extend(get_cvz_dataset("dapi")) # Add CTC datasets: cell segmentation for various - # if split_choice == "train": # NOTE: We use CTC only for training. - # _datasets.extend(get_ctc_datasets()) + if split_choice == "train": # NOTE: We use CTC only for training. + _datasets.extend(get_ctc_datasets()) generalist_dataset = ConcatDataset(*_datasets) # Increasing the sampling attempts for the NeurIPS CellSeg dataset. - # generalist_dataset.datasets[3].max_sampling_attempts = 5000 + generalist_dataset.datasets[3].max_sampling_attempts = 5000 return generalist_dataset @@ -233,9 +253,4 @@ def get_generalist_lm_loaders(input_path, patch_shape): train_loader = torch_em.get_data_loader(generalist_train_dataset, batch_size=2, shuffle=True, num_workers=16) val_loader = torch_em.get_data_loader(generalist_val_dataset, batch_size=1, shuffle=True, num_workers=16) - from torch_em.util.debug import check_loader - check_loader(train_loader, 16) - - breakpoint() - return train_loader, val_loader diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index 35a767c62..bba7f535a 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -262,7 +262,7 @@ def __call__(self, raw): raw = to_rgb(raw) # Ensure all images are in 3-channels: triplicate one channel to three channels. if self.do_rescaling: - # NOTE: Below is done for TissueNet: to work with the valid channels. + # NOTE: Below is done for TissueNet: to work with the valid channels (i.e. the first and second channels). raw = normalize_percentile(raw, axis=(0, 1)) raw = normalize(raw) raw = raw * 255 From a6d64a14d6d667024cf99ebb86db4239b238a332 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Wed, 1 Jan 2025 10:50:56 +0100 Subject: [PATCH 10/16] Update slurm submission scripts --- .../light_microscopy/obtain_lm_datasets.py | 3 +- .../light_microscopy/train_lm_generalist.py | 1 + finetuning/run_all_finetuning.py | 64 ++++++++++--------- micro_sam/training/training.py | 6 +- 4 files changed, 42 insertions(+), 32 deletions(-) diff --git a/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py b/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py index 31f5e513d..9138ab9ad 100644 --- a/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py +++ b/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py @@ -36,8 +36,9 @@ def _identity(x): def _cellpose_raw_trafo(x): """Transforms input images to desired format. + NOTE: The input channel logic is arranged a bit strangely in `cyto` dataset. - We take care of it here. + This function takes care of it here. """ r, g, b = x diff --git a/finetuning/generalists/training/light_microscopy/train_lm_generalist.py b/finetuning/generalists/training/light_microscopy/train_lm_generalist.py index 854f56fb6..2c17f337a 100644 --- a/finetuning/generalists/training/light_microscopy/train_lm_generalist.py +++ b/finetuning/generalists/training/light_microscopy/train_lm_generalist.py @@ -41,6 +41,7 @@ def finetune_lm_generalist(args): save_root=args.save_root, scheduler_kwargs=scheduler_kwargs, verify_n_labels_in_loader=None, # NOTE: Verifies all labels in the loader(s). + box_distortion_factor=0.05, ) if args.export_path is not None: diff --git a/finetuning/run_all_finetuning.py b/finetuning/run_all_finetuning.py index 7562e3744..9a0bdf166 100644 --- a/finetuning/run_all_finetuning.py +++ b/finetuning/run_all_finetuning.py @@ -5,43 +5,37 @@ N_OBJECTS = { - "vit_t": 50, + # "vit_t": 50, "vit_b": 40, "vit_l": 30, - "vit_h": 25 + "vit_h": 25, } -def write_batch_script(out_path, _name, env_name, model_type, save_root): +def write_batch_script(out_path, _name, env_name, model_type, save_root, dry): "Writing scripts with different micro-sam finetunings." batch_script = f"""#!/bin/bash #SBATCH -t 14-00:00:00 -#SBATCH --mem 64G #SBATCH --nodes=1 #SBATCH --ntasks=1 -#SBATCH -p grete:shared -#SBATCH -G A100:1 -#SBATCH -A nim00007 +#SBATCH -p grete-h100:shared +#SBATCH -G H100:1 +#SBATCH -A gzz0001 #SBATCH -c 16 +#SBATCH --mem 64G #SBATCH --qos=14d -#SBATCH --constraint=80gb #SBATCH --job-name={os.path.split(_name)[-1]} -source activate {env_name} \n""" +source ~/.bashrc +micromamba activate {env_name} \n""" - # python script + # The python script python_script = f"python {_name}.py " + python_script += f"-s {save_root} " # The save root folder + python_script += f"-m {model_type} " # The name of the model configuration + python_script += f"--n_objects {N_OBJECTS[model_type[:5]]} " # The choice of the number of objects - # save root folder - python_script += f"-s {save_root} " - - # name of the model configuration - python_script += f"-m {model_type} " - - # choice of the number of objects - python_script += f"--n_objects {N_OBJECTS[model_type[:5]]} " - - # let's add the python script to the bash script + # Add the python script to the bash script batch_script += python_script _op = out_path[:-3] + f"_{os.path.split(_name)[-1]}.sh" @@ -49,7 +43,8 @@ def write_batch_script(out_path, _name, env_name, model_type, save_root): f.write(batch_script) cmd = ["sbatch", _op] - subprocess.run(cmd) + if not dry: + subprocess.run(cmd) def get_batch_script_names(tmp_folder): @@ -101,17 +96,17 @@ def submit_slurm(args): write_batch_script( out_path=get_batch_script_names(tmp_folder), _name=script_name, - env_name="mobilesam" if model_type == "vit_t" else "sam", + env_name="mobilesam" if model_type == "vit_t" else "super", model_type=model_type, - save_root=args.save_root + save_root=args.save_root, + dry=args.dry, ) def main(args): - try: + tmp_dir = "./gpu_jobs" + if os.path.exists(tmp_dir): shutil.rmtree("./gpu_jobs") - except FileNotFoundError: - pass submit_slurm(args) @@ -119,8 +114,19 @@ def main(args): if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() - parser.add_argument("-e", "--experiment_name", type=str, default=None) - parser.add_argument("-s", "--save_root", type=str, default="/scratch/usr/nimanwai/micro-sam/") - parser.add_argument("-m", "--model_type", type=str, default=None) + parser.add_argument( + "-e", "--experiment_name", type=str, default=None, help="The choice of experiment name.", + ) + parser.add_argument( + "-s", "--save_root", type=str, default="/mnt/vast-nhr/projects/cidas/cca/experiments/micro_sam", + help="The path where to store the model checkpoints and logs.", + ) + parser.add_argument( + "-m", "--model_type", type=str, default=None, help="The choice of model type for Segment Anything model." + ) + parser.add_argument( + "--dry", action="store_true", help="Whether to submit the scripts to slurm or only store the scripts." + ) args = parser.parse_args() + main(args) diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py index 79485b8d0..a3a35963c 100644 --- a/micro_sam/training/training.py +++ b/micro_sam/training/training.py @@ -188,6 +188,7 @@ def train_sam( peft_kwargs: Optional[Dict] = None, ignore_warnings: bool = True, verify_n_labels_in_loader: Optional[int] = 50, + box_distortion_factor: Optional[float] = 0.025, **model_kwargs, ) -> None: """Run training for a SAM model. @@ -228,8 +229,9 @@ def train_sam( peft_kwargs: Keyword arguments for the PEFT wrapper class. verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders. By default, 50 batches of labels are verified from the dataloaders. - model_kwargs: Additional keyword arguments for the `util.get_sam_model`. ignore_warnings: Whether to ignore raised warnings. + box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks. + model_kwargs: Additional keyword arguments for the `util.get_sam_model`. """ with _filter_warnings(ignore_warnings): @@ -251,7 +253,7 @@ def train_sam( ) # This class creates all the training data for a batch (inputs, prompts and labels). - convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025) + convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=box_distortion_factor) # Create the UNETR decoder (if train with it) and the optimizer. if with_segmentation_decoder: From 7159986d289baff62dbcec5718a7c07081a34ef7 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Wed, 1 Jan 2025 15:40:54 +0100 Subject: [PATCH 11/16] Update checks for instance ids --- .../light_microscopy/obtain_lm_datasets.py | 10 +++---- micro_sam/util.py | 26 +++++++------------ 2 files changed, 13 insertions(+), 23 deletions(-) diff --git a/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py b/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py index 9138ab9ad..83f6f6539 100644 --- a/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py +++ b/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py @@ -63,14 +63,10 @@ def get_concat_lm_datasets(input_path, patch_shape, split_choice): label_dtype = torch.float32 sampler = MinInstanceSampler(min_size=10) - def _get_label_transform(min_size=10): + def _get_label_transform(): label_transform = PerObjectDistanceTransform( - distances=True, - boundary_distances=True, - directed_distances=False, - foreground=True, - instances=True, - min_size=min_size + distances=True, boundary_distances=True, directed_distances=False, + foreground=True, instances=True, min_size=10 ) return label_transform diff --git a/micro_sam/util.py b/micro_sam/util.py index 6489188c7..741eae9cc 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -868,10 +868,7 @@ def precompute_image_embeddings( def set_precomputed( - predictor: SamPredictor, - image_embeddings: ImageEmbeddings, - i: Optional[int] = None, - tile_id: Optional[int] = None, + predictor: SamPredictor, image_embeddings: ImageEmbeddings, i: Optional[int] = None, tile_id: Optional[int] = None, ) -> SamPredictor: """Set the precomputed image embeddings for a predictor. @@ -938,8 +935,7 @@ def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float: def get_centers_and_bounding_boxes( - segmentation: np.ndarray, - mode: str = "v" + segmentation: np.ndarray, mode: str = "v" ) -> Tuple[Dict[int, np.ndarray], Dict[int, tuple]]: """Returns the center coordinates of the foreground instances in the ground-truth. @@ -969,11 +965,7 @@ def get_centers_and_bounding_boxes( return center_coordinates, bbox_coordinates -def load_image_data( - path: str, - key: Optional[str] = None, - lazy_loading: bool = False -) -> np.ndarray: +def load_image_data(path: str, key: Optional[str] = None, lazy_loading: bool = False) -> np.ndarray: """Helper function to load image data from file. Args: @@ -994,10 +986,7 @@ def load_image_data( return image_data -def segmentation_to_one_hot( - segmentation: np.ndarray, - segmentation_ids: Optional[np.ndarray] = None, -) -> torch.Tensor: +def segmentation_to_one_hot(segmentation: np.ndarray, segmentation_ids: Optional[np.ndarray] = None) -> torch.Tensor: """Convert the segmentation to one-hot encoded masks. Args: @@ -1012,7 +1001,12 @@ def segmentation_to_one_hot( n_ids = int(segmentation.max()) else: - assert segmentation_ids[0] != 0, "No objects were found." + msg = "No foreground objects were found." + if len(segmentation_ids) == 0: # The list should not be completely empty. + raise AssertionError(msg) + + if segmentation_ids[0] == 0: # The list should not have zero-only. + raise AssertionError(msg) # the segmentation ids have to be sorted segmentation_ids = np.sort(segmentation_ids) From e154d4f9efb823dd53ceb289366852416b38aa42 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Wed, 1 Jan 2025 15:50:56 +0100 Subject: [PATCH 12/16] Update submission scripts --- finetuning/run_all_finetuning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/finetuning/run_all_finetuning.py b/finetuning/run_all_finetuning.py index 9a0bdf166..e3d3d3f87 100644 --- a/finetuning/run_all_finetuning.py +++ b/finetuning/run_all_finetuning.py @@ -8,7 +8,7 @@ # "vit_t": 50, "vit_b": 40, "vit_l": 30, - "vit_h": 25, + # "vit_h": 25, } From efb36c6a6adb2ac18d34205f58a7cb12e08638e8 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Sat, 4 Jan 2025 13:41:58 +0100 Subject: [PATCH 13/16] Update raw trafo for 8 bit inputs --- .../light_microscopy/obtain_lm_datasets.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py b/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py index 83f6f6539..22c086bac 100644 --- a/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py +++ b/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py @@ -17,13 +17,10 @@ def _to_8bit(raw): "Ensures three channels for inputs and rescale them to 8 bit." - if raw.ndim == 2: - raw = to_rgb(raw) # Ensure all images are in 3-channels: triplicate one channel to three channels. - else: - if raw.shape[0] != 3: - assert raw.shape[0] == 1, raw.shape - raw = np.concatenate([raw] * 3, axis=0) + if raw.ndim == 3 and raw.shape[0] == 1: # If the inputs have 1 channel, we triplicate it. + raw = np.concatenate([raw] * 3, axis=0) + raw = to_rgb(raw) # Ensure all images are in 3-channels: triplicate one channel to three channels. raw = normalize(raw) * 255 return raw @@ -65,8 +62,7 @@ def get_concat_lm_datasets(input_path, patch_shape, split_choice): def _get_label_transform(): label_transform = PerObjectDistanceTransform( - distances=True, boundary_distances=True, directed_distances=False, - foreground=True, instances=True, min_size=10 + distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True, ) return label_transform @@ -149,7 +145,7 @@ def get_ctc_datasets(): datasets.get_tissuenet_dataset( path=os.path.join(input_path, "tissuenet"), split=split_choice, download=True, patch_shape=patch_shape, raw_channel="rgb", label_channel="cell", raw_transform=ResizeRawTrafo((3, *patch_shape), do_rescaling=True), - label_transform=ResizeLabelTrafo(patch_shape, min_size=10), sampler=sampler, label_dtype=label_dtype, + label_transform=ResizeLabelTrafo(patch_shape), sampler=sampler, label_dtype=label_dtype, n_samples=500 if split_choice == "train" else 100, ), # bacteria segmentation in label-free microscopy images. @@ -163,7 +159,7 @@ def get_ctc_datasets(): path=os.path.join(input_path, "plantseg"), name="root", n_samples=500 if split_choice == "train" else 100, patch_shape=(1, *patch_shape), download=True, ndim=2, raw_transform=ResizeRawTrafo((3, *patch_shape)), sampler=MinInstanceSampler(min_num_instances=4, min_size=10), split=split_choice, label_dtype=label_dtype, - label_transform=ResizeLabelTrafo(patch_shape, min_size=10), + label_transform=ResizeLabelTrafo(patch_shape), ), # cell segmentation in multi-modal microscopy images. datasets.get_neurips_cellseg_supervised_dataset( From ceaee5b12ea168527e06daf0621acde5f7c8a49a Mon Sep 17 00:00:00 2001 From: anwai98 Date: Mon, 6 Jan 2025 13:29:33 +0100 Subject: [PATCH 14/16] Add supports for training SAM from scratch --- .../light_microscopy/train_lm_generalist.py | 10 ++++++---- micro_sam/training/training.py | 3 +++ micro_sam/training/util.py | 12 +++++++----- micro_sam/util.py | 17 ++++++++++------- 4 files changed, 26 insertions(+), 16 deletions(-) diff --git a/finetuning/generalists/training/light_microscopy/train_lm_generalist.py b/finetuning/generalists/training/light_microscopy/train_lm_generalist.py index 2c17f337a..e53d42046 100644 --- a/finetuning/generalists/training/light_microscopy/train_lm_generalist.py +++ b/finetuning/generalists/training/light_microscopy/train_lm_generalist.py @@ -42,6 +42,7 @@ def finetune_lm_generalist(args): scheduler_kwargs=scheduler_kwargs, verify_n_labels_in_loader=None, # NOTE: Verifies all labels in the loader(s). box_distortion_factor=0.05, + load_weights=(not args.from_scratch), ) if args.export_path is not None: @@ -49,9 +50,7 @@ def finetune_lm_generalist(args): "" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt" ) export_custom_sam_model( - checkpoint_path=checkpoint_path, - model_type=model_type, - save_path=args.export_path, + checkpoint_path=checkpoint_path, model_type=model_type, save_path=args.export_path ) @@ -63,7 +62,7 @@ def main(): ) parser.add_argument( "--model_type", "-m", default="vit_b", - help="The model type to use for fine-tuning. Either vit_t, vit_b, vit_l or vit_h." + help="The model type to use for fine-tuning. Either 'vit_t', 'vit_b', 'vit_l' or 'vit_h'." ) parser.add_argument( "--save_root", "-s", default=None, @@ -80,6 +79,9 @@ def main(): parser.add_argument( "--n_objects", type=int, default=25, help="The number of instances (objects) per batch used for finetuning." ) + parser.add_argument( + "--from_scratch", action="store_true", help="Whether to train Segment Anything model from scratch." + ) args = parser.parse_args() finetune_lm_generalist(args) diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py index a3a35963c..0c829d30a 100644 --- a/micro_sam/training/training.py +++ b/micro_sam/training/training.py @@ -189,6 +189,7 @@ def train_sam( ignore_warnings: bool = True, verify_n_labels_in_loader: Optional[int] = 50, box_distortion_factor: Optional[float] = 0.025, + load_weights: bool = True, **model_kwargs, ) -> None: """Run training for a SAM model. @@ -231,6 +232,7 @@ def train_sam( By default, 50 batches of labels are verified from the dataloaders. ignore_warnings: Whether to ignore raised warnings. box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks. + load_weights: Whether to initialize the model with pretrained parameter weights. model_kwargs: Additional keyword arguments for the `util.get_sam_model`. """ with _filter_warnings(ignore_warnings): @@ -249,6 +251,7 @@ def train_sam( checkpoint_path=checkpoint_path, return_state=True, peft_kwargs=peft_kwargs, + load_weights=load_weights, **model_kwargs ) diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index bba7f535a..a9d72b678 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -47,21 +47,22 @@ def get_trainable_sam_model( return_state: bool = False, peft_kwargs: Optional[Dict] = None, flexible_load_checkpoint: bool = False, + load_weights: bool = True, **model_kwargs ) -> TrainableSAM: """Get the trainable sam model. Args: - model_type: The segment anything model that should be finetuned. - The weights of this model will be used for initialization, unless a - custom weight file is passed via `checkpoint_path`. + model_type: The segment anything model that should be finetuned. The weights of this model + will be used for initialization, unless a custom weight file is passed via `checkpoint_path`. device: The device to use for training. checkpoint_path: Path to a custom checkpoint from which to load the model weights. - freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder - By default nothing is frozen and the full model is updated. + freeze: Specify parts of the model that should be frozen, namely: `image_encoder`, `prompt_encoder` and + `mask_decoder`. By default nothing is frozen and the full model is updated. return_state: Whether to return the full checkpoint state. peft_kwargs: Keyword arguments for the PEFT wrapper class. flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints. + load_weights: Whether to initialize the model with pretrained parameter weights. model_kwargs: Additional keyword arguments for the `util.get_sam_model`. Returns: @@ -76,6 +77,7 @@ def get_trainable_sam_model( return_sam=True, return_state=True, flexible_load_checkpoint=flexible_load_checkpoint, + load_weights=load_weights, **model_kwargs ) diff --git a/micro_sam/util.py b/micro_sam/util.py index d4906935f..ac83a4f8a 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -277,6 +277,7 @@ def get_sam_model( return_state: bool = False, peft_kwargs: Optional[Dict] = None, flexible_load_checkpoint: bool = False, + load_weights: bool = True, **model_kwargs, ) -> SamPredictor: r"""Get the SegmentAnything Predictor. @@ -301,17 +302,18 @@ def get_sam_model( https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html Args: - model_type: The SegmentAnything model to use. Will use the standard vit_h model by default. + model_type: The Segment Anything model to use. Will use the standard `vit_l` model by default. To get a list of all available model names you can call `get_model_names`. device: The device for the model. If none is given will use GPU if available. checkpoint_path: The path to a file with weights that should be used instead of using the weights corresponding to `model_type`. If given, `model_type` must match the architecture - corresponding to the weight file. E.g. if you use weights for SAM with vit_b encoder + corresponding to the weight file. e.g. if you use weights for SAM with `vit_b` encoder then `model_type` must be given as "vit_b". return_sam: Return the sam model object as well as the predictor. return_state: Return the unpickled checkpoint state. peft_kwargs: Keyword arguments for th PEFT wrapper class. flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints. + load_weights: Whether to initialize the model with pretrained parameter weights. model_kwargs: Additional parameters necessary to initialize the Segment Anything model. Returns: @@ -354,7 +356,7 @@ def get_sam_model( raise ValueError(f"Invalid model_type: {abbreviated_model_type}. Expect one of {_MODEL_TYPES}") if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT: raise RuntimeError( - "mobile_sam is required for the vit-tiny." + "'mobile_sam' is required for the vit-tiny. " "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'" ) @@ -378,10 +380,11 @@ def get_sam_model( sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam # In case the model checkpoints have some issues when it is initialized with different parameters than default. - if flexible_load_checkpoint: - sam = _handle_checkpoint_loading(sam, model_state) - else: - sam.load_state_dict(model_state) + if load_weights: + if flexible_load_checkpoint: + sam = _handle_checkpoint_loading(sam, model_state) + else: + sam.load_state_dict(model_state) sam.to(device=device) From cd1b098c77ca53654b848b5877840cb6e6a7a6fb Mon Sep 17 00:00:00 2001 From: anwai98 Date: Mon, 6 Jan 2025 22:50:44 +0100 Subject: [PATCH 15/16] Improve logic for raw trafo and checking foreground objects --- .../light_microscopy/obtain_lm_datasets.py | 7 ++-- micro_sam/training/util.py | 35 ++++++++++++------- micro_sam/util.py | 6 ++-- 3 files changed, 30 insertions(+), 18 deletions(-) diff --git a/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py b/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py index 22c086bac..f20eefde8 100644 --- a/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py +++ b/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py @@ -144,9 +144,10 @@ def get_ctc_datasets(): # cell segmentation in tissue microscopy images. datasets.get_tissuenet_dataset( path=os.path.join(input_path, "tissuenet"), split=split_choice, download=True, patch_shape=patch_shape, - raw_channel="rgb", label_channel="cell", raw_transform=ResizeRawTrafo((3, *patch_shape), do_rescaling=True), - label_transform=ResizeLabelTrafo(patch_shape), sampler=sampler, label_dtype=label_dtype, - n_samples=500 if split_choice == "train" else 100, + raw_channel="rgb", label_channel="cell", label_transform=ResizeLabelTrafo(patch_shape), + # NOTE: Below is done for TissueNet: to work with the valid channels (i.e. the first and second channels). + raw_transform=ResizeRawTrafo((3, *patch_shape), do_rescaling=True, valid_channels=(0, 1)), + n_samples=500 if split_choice == "train" else 100, sampler=sampler, label_dtype=label_dtype, ), # bacteria segmentation in label-free microscopy images. datasets.get_deepbacs_dataset( diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index a9d72b678..1097efa4e 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -1,6 +1,6 @@ import os from math import ceil, floor -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Union, Tuple import numpy as np @@ -188,11 +188,13 @@ def __call__(self, x, y, n_pos, n_neg, get_boxes=False, n_samples=None): get_points = True # keeping the solution open by checking for deterministic/dynamic choice of point prompts - prompt_generator = PointAndBoxPromptGenerator(n_positive_points=n_pos, - n_negative_points=n_neg, - dilation_strength=self.dilation_strength, - get_box_prompts=get_boxes, - get_point_prompts=get_points) + prompt_generator = PointAndBoxPromptGenerator( + n_positive_points=n_pos, + n_negative_points=n_neg, + dilation_strength=self.dilation_strength, + get_box_prompts=get_boxes, + get_point_prompts=get_points + ) batched_inputs = [] batched_sampled_cell_ids_list = [] @@ -218,6 +220,7 @@ def __call__(self, x, y, n_pos, n_neg, get_boxes=False, n_samples=None): batched_input["boxes"] = self.transform.apply_boxes_torch( box_prompts, original_size=gt.shape[-2:] ) if self.transform is not None else box_prompts + if get_points: batched_input["point_coords"] = self.transform.apply_coords_torch( point_prompts, original_size=gt.shape[-2:] @@ -255,17 +258,23 @@ def normalize_to_8bit(raw): class ResizeRawTrafo: - def __init__(self, desired_shape, do_rescaling=False, padding="constant"): + def __init__( + self, + desired_shape: Tuple[int, ...], + do_rescaling: bool = False, + valid_channels: Optional[Union[int, Tuple[int, ...]]] = None, + padding: str = "constant" + ): self.desired_shape = desired_shape - self.padding = padding self.do_rescaling = do_rescaling + self.valid_channels = valid_channels + self.padding = padding def __call__(self, raw): raw = to_rgb(raw) # Ensure all images are in 3-channels: triplicate one channel to three channels. if self.do_rescaling: - # NOTE: Below is done for TissueNet: to work with the valid channels (i.e. the first and second channels). - raw = normalize_percentile(raw, axis=(0, 1)) + raw = normalize_percentile(raw, axis=self.valid_channels) raw = normalize(raw) raw = raw * 255 @@ -280,10 +289,12 @@ def __call__(self, raw): class ResizeLabelTrafo: - def __init__(self, desired_shape, padding="constant", min_size=0): + def __init__( + self, desired_shape: Tuple[int, ...], min_size: int = 0, padding: str = "constant", + ): self.desired_shape = desired_shape - self.padding = padding self.min_size = min_size + self.padding = padding def __call__(self, labels): distance_trafo = PerObjectDistanceTransform( diff --git a/micro_sam/util.py b/micro_sam/util.py index ac83a4f8a..867a47ec6 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -995,10 +995,10 @@ def segmentation_to_one_hot(segmentation: np.ndarray, segmentation_ids: Optional else: msg = "No foreground objects were found." if len(segmentation_ids) == 0: # The list should not be completely empty. - raise AssertionError(msg) + raise RuntimeError(msg) - if segmentation_ids[0] == 0: # The list should not have zero-only. - raise AssertionError(msg) + if 0 in segmentation_ids: # The list should not have 'zero' as a value. + raise RuntimeError(msg) # the segmentation ids have to be sorted segmentation_ids = np.sort(segmentation_ids) From 2beebdadba18ec4034d6bcaf3abb232774398dae Mon Sep 17 00:00:00 2001 From: anwai98 Date: Mon, 6 Jan 2025 23:30:35 +0100 Subject: [PATCH 16/16] Replace inplace tensor operations for fact --- micro_sam/models/peft_sam.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py index d72295c61..2fbb4c3d5 100644 --- a/micro_sam/models/peft_sam.py +++ b/micro_sam/models/peft_sam.py @@ -114,8 +114,13 @@ def forward(self, x): new_v = self.FacTv(new_v) # NOTE : Scaling Factor is set to 1 as it can be tuned via the learning rate. - qkv[:, :, :, : self.dim] += new_q - qkv[:, :, :, -self.dim:] += new_v + qkv = torch.cat( + [ + qkv[:, :, :, :self.dim] + new_q, # replacing new q values + qkv[:, :, :, self.dim:-self.dim], # leaving the middle part as identical + qkv[:, :, :, -self.dim:] + new_v # replacing new v values + ], dim=-1 + ) return qkv