Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mirroring test-time augmentation #96

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
82 changes: 59 additions & 23 deletions monai/run_inference_single_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@

from monai.inferers import sliding_window_inference
from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch)
from monai.transforms import (Compose, EnsureTyped, Invertd, SaveImage, Spacingd,
LoadImaged, NormalizeIntensityd, EnsureChannelFirstd,
DivisiblePadd, Orientationd, ResizeWithPadOrCropd)
import monai.transforms as transforms
from dynamic_network_architectures.architectures.unet import PlainConvUNet, ResidualEncoderUNet
from dynamic_network_architectures.building_blocks.helper import get_matching_instancenorm, convert_dim_to_conv_op
from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0
Expand Down Expand Up @@ -72,6 +70,8 @@ def get_parser():
' Default: 64x192x-1')
parser.add_argument('--device', default="gpu", type=str, choices=["gpu", "cpu"],
help='Device to run inference on. Default: cpu')
parser.add_argument('--use-tta', action='store_true',
help='Use test-time augmentation (TTA) i.e mirroring on all axes. Default: False')
naga-karthik marked this conversation as resolved.
Show resolved Hide resolved

return parser

Expand All @@ -80,16 +80,40 @@ def get_parser():
# Test-time Transforms
# ===========================================================================
def inference_transforms_single_image(crop_size):
return Compose([
LoadImaged(keys=["image"], image_only=False),
EnsureChannelFirstd(keys=["image"]),
Orientationd(keys=["image"], axcodes="RPI"),
Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0), mode=(2)),
ResizeWithPadOrCropd(keys=["image"], spatial_size=crop_size,),
DivisiblePadd(keys=["image"], k=2**5), # pad inputs to ensure divisibility by no. of layers nnUNet has (5)
NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False),
return transforms.Compose([
transforms.LoadImaged(keys=["image"], image_only=False),
transforms.EnsureChannelFirstd(keys=["image"]),
transforms.Orientationd(keys=["image"], axcodes="RPI"),
transforms.Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0), mode=(2)),
transforms.ResizeWithPadOrCropd(keys=["image"], spatial_size=crop_size,),
transforms.DivisiblePadd(keys=["image"], k=2**5), # pad inputs to ensure divisibility by no. of layers nnUNet has (5)
transforms.NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False),
])

# # ===========================================================================
# # Test-time Augmentation Transforms
# # (_could_ use same transforms as in train_transforms but not using for now)
# # (TODO: revisit this in the future)
# # ===========================================================================
# def inference_transforms_tta(crop_size):
# return transforms.Compose([
# transforms.LoadImaged(keys=["image"], image_only=False),
# transforms.EnsureChannelFirstd(keys=["image"]),
# transforms.Orientationd(keys=["image"], axcodes="RPI"),
# transforms.Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0), mode=(2)),
# transforms.ResizeWithPadOrCropd(keys=["image"], spatial_size=crop_size,),
# transforms.DivisiblePadd(keys=["image"], k=2**5), # pad inputs to ensure divisibility by no. of layers nnUNet has (5)
# # use the same transforms as in train_transforms
# # transforms.RandAffined(keys=["image"], mode=(2), prob=0.9,
# # rotate_range=(-20. / 360 * 2. * np.pi, 20. / 360 * 2. * np.pi), # monai expects in radians
# # scale_range=(-0.2, 0.2),
# # translate_range=(-0.1, 0.1)),
# # transforms.RandHistogramShiftd(keys=["image"], prob=1.0, num_control_points=10),
# # transforms.RandFlipd(keys=["image"], prob=1, spatial_axis=0),
# # transforms.RandAdjustContrastd(keys=["image"], gamma=(0.5, 3.), prob=1.0), # this is monai's RandomGamma
# transforms.NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False),
# ])


# ===========================================================================
# Model utils
Expand Down Expand Up @@ -177,7 +201,7 @@ def create_nnunet_from_plans(plans, num_input_channels: int, num_classes: int, d
# ===========================================================================
# Prepare temporary dataset for inference
# ===========================================================================
def prepare_data(path_image, path_out, crop_size=(64, 160, 320)):
def prepare_data(path_image, path_out, crop_size=(64, 160, 320), tta=False):

# create a temporary datalist containing the image
# boiler plate keys to be defined in the MSD-style datalist
Expand Down Expand Up @@ -206,13 +230,15 @@ def prepare_data(path_image, path_out, crop_size=(64, 160, 320)):
test_files = load_decathlon_datalist(dataset, True, "test")

# define test transforms
if tta:
logger.info("Using test-time augmentation (mirroring across all 3 axes) ...")
transforms_test = inference_transforms_single_image(crop_size=crop_size)

# define post-processing transforms for testing; taken (with explanations) from
# https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66
test_post_pred = Compose([
EnsureTyped(keys=["pred"]),
Invertd(keys=["pred"], transform=transforms_test,
test_post_pred = transforms.Compose([
transforms.EnsureTyped(keys=["pred"]),
transforms.Invertd(keys=["pred"], transform=transforms_test,
orig_keys=["image"],
meta_keys=["pred_meta_dict"],
nearest_interp=False, to_tensor=True),
Expand Down Expand Up @@ -254,7 +280,7 @@ def main():
inference_roi_size = (64, 192, 320)

# define the dataset and dataloader
test_ds, test_post_pred = prepare_data(path_image, results_path, crop_size=crop_size)
test_ds, test_post_pred = prepare_data(path_image, results_path, crop_size=crop_size, tta=args.use_tta)
test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=8, pin_memory=True)

# define model
Expand Down Expand Up @@ -288,12 +314,22 @@ def main():
net.to(DEVICE)
net.eval()

# run inference
batch["pred"] = sliding_window_inference(test_input, inference_roi_size, mode="gaussian",
sw_batch_size=4, predictor=net, overlap=0.5, progress=False)

# take only the highest resolution prediction
batch["pred"] = batch["pred"][0]
# test-time augmentation
if args.use_tta:
# iterate over the x, y, z axes
batch["pred"] = torch.zeros_like(test_input)
for axis in range(3):
# flip the input, run inference and flip it back
batch["pred"] += torch.flip(sliding_window_inference(
torch.flip(test_input, dims=[axis]), inference_roi_size, mode="gaussian",
sw_batch_size=4, predictor=net, overlap=0.5, progress=False)[0], dims=[axis]
)
# average the prediction
batch["pred"] /= 3
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just wondering if averaging should be done after clamping ?
Because averaging is done before clamping, if a voxel is seen only once, its value will be 0.33 and therefore will be removed afterward at clamping (done at 0.5).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh man, this is an excellent catch! thanks for pointing this much important thing out! I'm going to fix it in the next commit!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in commit 4beb690

else:
# run inference and take highest resolution prediction
batch["pred"] = sliding_window_inference(test_input, inference_roi_size, mode="gaussian",
sw_batch_size=4, predictor=net, overlap=0.5, progress=False)[0]

# NOTE: monai's models do not normalize the output, so we need to do it manually
if bool(F.relu(batch["pred"]).max()):
Expand All @@ -318,7 +354,7 @@ def main():

naga-karthik marked this conversation as resolved.
Show resolved Hide resolved
# this takes about 0.25s on average on a CPU
# image saver class
pred_saver = SaveImage(
pred_saver = transforms.SaveImage(
output_dir=results_path, output_postfix="pred", output_ext=".nii.gz",
separate_folder=False, print_log=False)
# save the prediction
Expand Down