Add mirroring test-time augmentation #96

Open
wants to merge 10 commits into main
132 changes: 108 additions & 24 deletions monai/run_inference_single_image.py
@@ -14,12 +14,11 @@
import torch.nn as nn
import json
from time import time
import scipy.ndimage as ndimage

from monai.inferers import sliding_window_inference
from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch)
from monai.transforms import (Compose, EnsureTyped, Invertd, SaveImage, Spacingd,
LoadImaged, NormalizeIntensityd, EnsureChannelFirstd,
DivisiblePadd, Orientationd, ResizeWithPadOrCropd)
import monai.transforms as transforms
from dynamic_network_architectures.architectures.unet import PlainConvUNet, ResidualEncoderUNet
from dynamic_network_architectures.building_blocks.helper import get_matching_instancenorm, convert_dim_to_conv_op
from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0
@@ -29,6 +28,7 @@
INIT_FILTERS=32
ENABLE_DS = True

# WARNING: Do NOT modify this
nnunet_plans = {
"UNet_class_name": "PlainConvUNet",
"UNet_base_num_features": INIT_FILTERS,
@@ -72,6 +72,11 @@ def get_parser():
' Default: 64x192x-1')
parser.add_argument('--device', default="gpu", type=str, choices=["gpu", "cpu"],
help='Device to run inference on. Default: gpu')
parser.add_argument('--use-tta', action='store_true',
help='Use test-time augmentation (TTA), i.e. mirroring across all 3 axes. Default: False')
parser.add_argument('--remove-small-objects', type=int, default=10,
help='Remove all unconnected objects smaller than the minimum specified size.'
' Defined as a percent of the total no. of voxels in the prediction. Default: 10')

return parser

@@ -80,16 +85,40 @@ def get_parser():
# Test-time Transforms
# ===========================================================================
def inference_transforms_single_image(crop_size):
return Compose([
LoadImaged(keys=["image"], image_only=False),
EnsureChannelFirstd(keys=["image"]),
Orientationd(keys=["image"], axcodes="RPI"),
Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0), mode=(2)),
ResizeWithPadOrCropd(keys=["image"], spatial_size=crop_size,),
DivisiblePadd(keys=["image"], k=2**5), # pad inputs to ensure divisibility by no. of layers nnUNet has (5)
NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False),
return transforms.Compose([
transforms.LoadImaged(keys=["image"], image_only=False),
transforms.EnsureChannelFirstd(keys=["image"]),
transforms.Orientationd(keys=["image"], axcodes="RPI"),
transforms.Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0), mode=(2)),
transforms.ResizeWithPadOrCropd(keys=["image"], spatial_size=crop_size,),
transforms.DivisiblePadd(keys=["image"], k=2**5), # pad inputs to ensure divisibility by no. of layers nnUNet has (5)
transforms.NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False),
])
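A quick sanity check of the `DivisiblePadd(k=2**5)` choice above: the nnUNet configured here has 5 downsampling stages, each halving the spatial dimensions, so every spatial dimension must be a multiple of 2**5 = 32. A minimal sketch (not part of the PR) using MONAI's array-based `DivisiblePad`, the non-dictionary counterpart of `DivisiblePadd`:

```python
import torch
from monai.transforms import DivisiblePad

# Each of the 5 downsampling stages halves the spatial dims,
# so inputs must be divisible by 2**5 = 32.
pad = DivisiblePad(k=32)
x = torch.zeros(1, 64, 190, 300)  # (C, D, H, W)
print(pad(x).shape)               # (1, 64, 192, 320): 190 -> 192, 300 -> 320
```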

# # ===========================================================================
# # Test-time Augmentation Transforms
# # (_could_ use same transforms as in train_transforms but not using for now)
# # (TODO: revisit this in the future)
# # ===========================================================================
# def inference_transforms_tta(crop_size):
# return transforms.Compose([
# transforms.LoadImaged(keys=["image"], image_only=False),
# transforms.EnsureChannelFirstd(keys=["image"]),
# transforms.Orientationd(keys=["image"], axcodes="RPI"),
# transforms.Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0), mode=(2)),
# transforms.ResizeWithPadOrCropd(keys=["image"], spatial_size=crop_size,),
# transforms.DivisiblePadd(keys=["image"], k=2**5), # pad inputs to ensure divisibility by no. of layers nnUNet has (5)
# # use the same transforms as in train_transforms
# # transforms.RandAffined(keys=["image"], mode=(2), prob=0.9,
# # rotate_range=(-20. / 360 * 2. * np.pi, 20. / 360 * 2. * np.pi), # monai expects in radians
# # scale_range=(-0.2, 0.2),
# # translate_range=(-0.1, 0.1)),
# # transforms.RandHistogramShiftd(keys=["image"], prob=1.0, num_control_points=10),
# # transforms.RandFlipd(keys=["image"], prob=1, spatial_axis=0),
# # transforms.RandAdjustContrastd(keys=["image"], gamma=(0.5, 3.), prob=1.0), # this is monai's RandomGamma
# transforms.NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False),
# ])


# ===========================================================================
# Model utils
@@ -105,6 +134,46 @@ def __call__(self, module):
module.bias = nn.init.constant_(module.bias, 0)


# ============================================================================
# Helper function(s)
# ============================================================================

def remove_small_objects(data, size_min_percent=10):
"""Removes all unconnected objects smaller than the minimum specified size.
(hence keeping only sufficiently large objects)
Adapted from: https://github.com/ivadomed/ivadomed/blob/master/ivadomed/postprocessing.py#L224

Args:
data (ndarray): Input data.
size_min_percent (int): Minimum size of objects to keep, as a percent of the
total number of annotated voxels.
e.g. size_min_percent=10 means that objects smaller than 10% of the total voxels will be removed.

Returns:
ndarray: Array with small objects removed.
"""

bin_structure = ndimage.generate_binary_structure(3, 2)

# get the total number of voxels annotated
n_voxels_total = np.count_nonzero(data)

# squeeze the first dimension (to be compatible with generate_binary_structure rank 3)
data = data.squeeze(axis=0)

data_label, n = ndimage.label(data, structure=bin_structure)

for idx in range(1, n + 1):
data_idx = (data_label == idx).astype(int)
n_nonzero = np.count_nonzero(data_idx)

# only keep connected objects larger than size_min_percent of the total number of voxels
if n_nonzero < (size_min_percent / 100) * n_voxels_total:
data[data_label == idx] = 0

return data
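A minimal usage sketch on synthetic data (not part of the PR). Note that the function writes through a view of its input (`squeeze` returns a view), hence the `.copy()` here, and that the returned array has the channel axis squeezed away:

```python
import numpy as np

# One 4x4x4 cube (64 voxels) and one isolated voxel (~1.5% of the foreground)
pred = np.zeros((1, 10, 10, 10))
pred[0, 2:6, 2:6, 2:6] = 1
pred[0, 8, 8, 8] = 1

cleaned = remove_small_objects(pred.copy(), size_min_percent=10)
print(cleaned.shape)              # (10, 10, 10) -- channel axis squeezed away
print(np.count_nonzero(cleaned))  # 64 -- the single-voxel object was removed
```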


# ============================================================================
# Define the network based on nnunet_plans dict
# ============================================================================
@@ -177,7 +246,7 @@ def create_nnunet_from_plans(plans, num_input_channels: int, num_classes: int, d
# ===========================================================================
# Prepare temporary dataset for inference
# ===========================================================================
def prepare_data(path_image, path_out, crop_size=(64, 160, 320)):
def prepare_data(path_image, path_out, crop_size=(64, 160, 320), tta=False):

# create a temporary datalist containing the image
# boiler plate keys to be defined in the MSD-style datalist
@@ -206,13 +275,15 @@ def prepare_data(path_image, path_out, crop_size=(64, 160, 320)):
test_files = load_decathlon_datalist(dataset, True, "test")

# define test transforms
if tta:
logger.info("Using test-time augmentation (mirroring across all 3 axes) ...")
transforms_test = inference_transforms_single_image(crop_size=crop_size)

# define post-processing transforms for testing; taken (with explanations) from
# https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66
test_post_pred = Compose([
EnsureTyped(keys=["pred"]),
Invertd(keys=["pred"], transform=transforms_test,
test_post_pred = transforms.Compose([
transforms.EnsureTyped(keys=["pred"]),
transforms.Invertd(keys=["pred"], transform=transforms_test,
orig_keys=["image"],
meta_keys=["pred_meta_dict"],
nearest_interp=False, to_tensor=True),
@@ -254,7 +325,7 @@ def main():
inference_roi_size = (64, 192, 320)

# define the dataset and dataloader
test_ds, test_post_pred = prepare_data(path_image, results_path, crop_size=crop_size)
test_ds, test_post_pred = prepare_data(path_image, results_path, crop_size=crop_size, tta=args.use_tta)
test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=8, pin_memory=True)

# define model
@@ -288,12 +359,22 @@ def main():
net.to(DEVICE)
net.eval()

# run inference
batch["pred"] = sliding_window_inference(test_input, inference_roi_size, mode="gaussian",
sw_batch_size=4, predictor=net, overlap=0.5, progress=False)

# take only the highest resolution prediction
batch["pred"] = batch["pred"][0]
# test-time augmentation
if args.use_tta:
# iterate over the x, y, z axes
batch["pred"] = torch.zeros_like(test_input)
for axis in range(3):
# flip the input, run inference and flip it back
batch["pred"] += torch.flip(sliding_window_inference(
torch.flip(test_input, dims=[axis]), inference_roi_size, mode="gaussian",
sw_batch_size=4, predictor=net, overlap=0.5, progress=False)[0], dims=[axis]
)
# average the prediction
batch["pred"] /= 3
Collaborator:

Just wondering if averaging should be done after clamping? Because averaging is done before clamping, if a voxel is seen only once, its value will be 0.33 and will therefore be removed afterwards at clamping (done at 0.5).

Collaborator (Author):

Oh man, this is an excellent catch, thanks for pointing out this important issue! I'm going to fix it in the next commit!

Collaborator (Author):

Fixed in commit 4beb690

else:
# run inference and take highest resolution prediction
batch["pred"] = sliding_window_inference(test_input, inference_roi_size, mode="gaussian",
sw_batch_size=4, predictor=net, overlap=0.5, progress=False)[0]
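To make the reviewer's point above concrete, here is a toy numeric sketch of the ordering issue (illustrative only; the actual change in commit 4beb690 is not shown in this diff):

```python
import torch

# A voxel predicted at 0.9 by one mirrored pass and 0.0 by the other two.
flip_preds = torch.tensor([0.9, 0.0, 0.0])

# Average first, clamp at 0.5 afterwards (the order in this diff):
# 0.9 / 3 = 0.3 < 0.5, so the detection is zeroed out.
avg = flip_preds.mean()
print(avg * (avg > 0.5))  # tensor(0.)

# Clamp each pass first, then average (one possible remedy):
# the 0.9 survives its own thresholding, so the combined value stays nonzero.
clamped = torch.where(flip_preds > 0.5, flip_preds, torch.zeros_like(flip_preds))
print(clamped.mean())     # tensor(0.3000)
```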

# NOTE: monai's models do not normalize the output, so we need to do it manually
if bool(F.relu(batch["pred"]).max()):
@@ -311,14 +392,17 @@ def main():
pred = torch.clamp(pred, 0.5, 1)
# set background values to 0
pred[pred <= 0.5] = 0


# remove small objects
pred = remove_small_objects(pred, size_min_percent=args.remove_small_objects)

# get subject name
subject_name = (batch["image_meta_dict"]["filename_or_obj"][0]).split("/")[-1].replace(".nii.gz", "")
logger.info(f"Saving subject: {subject_name}")

# this takes about 0.25s on average on a CPU
# image saver class
pred_saver = SaveImage(
pred_saver = transforms.SaveImage(
output_dir=results_path, output_postfix="pred", output_ext=".nii.gz",
separate_folder=False, print_log=False)
# save the prediction