Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HED augmentations for digital pathology image #649

Merged
merged 57 commits into from
Jul 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
16e7ca8
Update __init__.py
Geeks-Sid May 15, 2023
177565b
Update rgb_augs.py with augmentorbase
Geeks-Sid May 15, 2023
884365e
added sigma range and its checker
Geeks-Sid May 16, 2023
13dba52
Finished augmentor
Geeks-Sid May 16, 2023
a893604
Update rgb_augs.py to apply hed_transform
Geeks-Sid May 16, 2023
4e91c49
Update rgb_augs.py
Geeks-Sid May 16, 2023
bedef40
Create exceptions.py
Geeks-Sid May 16, 2023
2759ad0
Update parseConfig.py
Geeks-Sid May 16, 2023
eed74b5
Merge branch 'mlcommons:master' into hed_adv
Geeks-Sid May 16, 2023
22aff7d
update code fix
Geeks-Sid May 16, 2023
f4039e3
updated tests
Geeks-Sid May 16, 2023
787feae
fix augs import
Geeks-Sid May 16, 2023
1b27782
updated test to cover
Geeks-Sid May 16, 2023
1922af6
added testfull
Geeks-Sid May 16, 2023
29b7e6f
Merge branch 'mlcommons:master' into hed_adv
Geeks-Sid May 16, 2023
e025af7
fixed parseconfig
Geeks-Sid May 16, 2023
16cb62b
final fix
Geeks-Sid May 16, 2023
66bb31d
Merge branch 'master' into hed_adv
Geeks-Sid May 18, 2023
2f94198
some changes but still bugged
Geeks-Sid May 18, 2023
f98e2fa
Merge branch 'master' into hed_adv
Geeks-Sid May 19, 2023
235ce38
final fixes
Geeks-Sid May 19, 2023
606e855
Merge remote-tracking branch 'origin/hed_adv' into hed_adv
Geeks-Sid May 19, 2023
f2752ca
removed incorrect printing
Geeks-Sid May 19, 2023
f519663
Merge branch 'master' into hed_adv
Geeks-Sid May 22, 2023
85b5454
tests now pass. Find incorrect tests.
Geeks-Sid May 22, 2023
af34a05
code should cover now
Geeks-Sid May 23, 2023
36910e3
Merge branch 'master' into hed_adv
Geeks-Sid May 23, 2023
d579ec6
should run tests
Geeks-Sid May 23, 2023
5032c11
Updates for codacy, might need to cover tests
Geeks-Sid May 23, 2023
d5266b1
fix error in config_options
Geeks-Sid May 23, 2023
3f37af8
fix condition flips
Geeks-Sid May 23, 2023
27c1446
Merge branch 'master' into hed_adv
Geeks-Sid Jun 4, 2023
650a59d
Merge branch 'master' into hed_adv
Geeks-Sid Jun 13, 2023
1f76fa3
Merge branch 'master' into hed_adv
sarthakpati Jul 4, 2023
5ae8842
Merge branch 'mlcommons:master' into hed_adv
Geeks-Sid Jul 6, 2023
9db86c9
fixed conditions for a check
Geeks-Sid Jul 6, 2023
b8ee481
Update GANDLF/data/augmentation/rgb_augs.py
sarthakpati Jul 6, 2023
523a936
Merge branch 'master' into hed_adv
sarthakpati Jul 7, 2023
a5f80ed
Merge branch 'master' into hed_adv
Geeks-Sid Jul 13, 2023
ad75970
test coverage increase
Geeks-Sid Jul 13, 2023
e6facef
Merge remote-tracking branch 'origin/hed_adv' into hed_adv
Geeks-Sid Jul 13, 2023
16c9a53
Merge branch 'master' into hed_adv
sarthakpati Jul 14, 2023
c7dca1d
putting default probability above all augs
sarthakpati Jul 18, 2023
a6cfe09
putting the usage in the same place
sarthakpati Jul 18, 2023
670ddb4
updated comment
sarthakpati Jul 18, 2023
f09b72a
transform is getting applied, now
sarthakpati Jul 18, 2023
41b9496
added comment where tests are failing
sarthakpati Jul 18, 2023
191292e
logic fixed, added more comments
sarthakpati Jul 18, 2023
5b85027
updated api reference for clarity
sarthakpati Jul 18, 2023
2e4f7da
needed for more
sarthakpati Jul 18, 2023
878980d
updated api reference for clarity
sarthakpati Jul 18, 2023
cc0a900
Update test_full.py
Geeks-Sid Jul 19, 2023
4fba52e
Merge branch 'master' into hed_adv
Geeks-Sid Jul 19, 2023
1416174
Merge branch 'hed_adv' of https://github.com/Geeks-Sid/GaNDLF into he…
sarthakpati Jul 20, 2023
7fe43db
Merge pull request #10 from sarthakpati/hed_augs_siddhesh
Geeks-Sid Jul 20, 2023
343a4c5
commented out light and heavy
sarthakpati Jul 21, 2023
1cc449a
commented out light and heavy
sarthakpati Jul 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions GANDLF/data/augmentation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
rotate_180,
)
from .rgb_augs import colorjitter_transform
from .hed_augs import hed_transform

# Defining a dictionary for augmentations - key is the string and the value is the augmentation object
global_augs_dict = {
Expand All @@ -35,4 +36,5 @@
"rotate_180": rotate_180,
"anisotropic": anisotropy,
"colorjitter": colorjitter_transform,
"hed_transform": hed_transform,
}
314 changes: 314 additions & 0 deletions GANDLF/data/augmentation/hed_augs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,314 @@
from typing import Tuple, Union
import numpy as np
import torch
from skimage.color import rgb2hed, hed2rgb
from torchio.transforms.augmentation import RandomTransform
from torchio.transforms import IntensityTransform
from torchio import Subject


def hed_transform(parameters):
return RandomHEDTransform(
haematoxylin_sigma_range=parameters["haematoxylin_sigma_range"],
haematoxylin_bias_range=parameters["haematoxylin_bias_range"],
eosin_sigma_range=parameters["eosin_sigma_range"],
eosin_bias_range=parameters["eosin_bias_range"],
dab_sigma_range=parameters["dab_sigma_range"],
dab_bias_range=parameters["dab_bias_range"],
cutoff_range=parameters["cutoff_range"],
)


class AugmenterBase:
"""Base class for patch augmentation with a hed transform"""

def __init__(self, keyword):
"""
Args:
keyword (str): Short name for the transformation.
"""
self._keyword = keyword

## commented the following lines because the user is never given access to these
# @property
# def keyword(self):
# """Get the keyword for the augmenter."""
# return self._keyword

# def shapes(self, target_shapes):
# """Calculate the required shape of the input to achieve the target output shape."""
# return target_shapes

# def transform(self, patch):
# """Transform the given patch."""
# return patch

# def randomize(self):
# """Randomize the parameters of the augmenter."""
# return


class ColorAugmenterBase(AugmenterBase):
"""Base class for color patch augmentation."""

def __init__(self, keyword):
"""
Initialize the object.
Args:
keyword (str): Short name for the transformation.
"""

# Initialize the base class.
super().__init__(keyword=keyword)


class HedColorAugmenter(ColorAugmenterBase):
"""Apply color correction in HED color space on the RGB patch."""

def __init__(
self,
haematoxylin_sigma_range: Union[tuple, None],
haematoxylin_bias_range: Union[tuple, None],
eosin_sigma_range: Union[tuple, None],
eosin_bias_range: Union[tuple, None],
dab_sigma_range: Union[tuple, None],
dab_bias_range: Union[tuple, None],
cutoff_range: Union[tuple, None],
) -> ColorAugmenterBase:
"""
The following code is derived and inspired from the following sources: https://github.com/sebastianffx/stainlib.

Args:
haematoxylin_sigma_range (Union[tuple, None]): Adjustment range for the Haematoxylin channel from the [-1.0, 1.0] range where 0.0 means no change. For example (-0.1, 0.1).
haematoxylin_bias_range (Union[tuple, None]): Bias range for the Haematoxylin channel from the [-1.0, 1.0] range where 0.0 means no change. For example (-0.2, 0.2).
eosin_sigma_range (Union[tuple, None]): Adjustment range for the Eosin channel from the [-1.0, 1.0] range where 0.0 means no change.
eosin_bias_range (Union[tuple, None]): Bias range for the Eosin channel from the [-1.0, 1.0] range where 0.0 means no change.
dab_sigma_range (Union[tuple, None]): Adjustment range for the DAB channel from the [-1.0, 1.0] range where 0.0 means no change.
dab_bias_range (Union[tuple, None]): Bias range for the DAB channel from the [-1.0, 1.0] range where 0.0 means no change.
cutoff_range (Union[tuple, None]): Patches with mean value outside the cutoff interval will not be augmented. Values from the [0.0, 1.0] range. The RGB channel values are from the same range.

Returns:
ColorAugmenterBase: _description_
"""

# Initialize base class.
super().__init__(keyword="hed_color")

# Initialize members.
self._sigma_ranges = None # Configured sigma ranges for H, E, and D channels.
self._bias_ranges = None # Configured bias ranges for H, E, and D channels.
self._cutoff_range = None # Cutoff interval.
self._sigmas = None # Randomized sigmas for H, E, and D channels.
self._biases = None # Randomized biases for H, E, and D channels.

# Save configuration.
self._setsigmaranges(
haematoxylin_sigma_range=haematoxylin_sigma_range,
eosin_sigma_range=eosin_sigma_range,
dab_sigma_range=dab_sigma_range,
)
self._setbiasranges(
haematoxylin_bias_range=haematoxylin_bias_range,
eosin_bias_range=eosin_bias_range,
dab_bias_range=dab_bias_range,
)
self._setcutoffrange(cutoff_range=cutoff_range)

def _setsigmaranges(
self,
haematoxylin_sigma_range: Union[tuple, None],
eosin_sigma_range: Union[tuple, None],
dab_sigma_range: Union[tuple, None],
):
"""
Set the sigma intervals.

Args:
haematoxylin_sigma_range (Union[tuple, None]): Adjustment range for the Haematoxylin channel.
eosin_sigma_range (Union[tuple, None]): Adjustment range for the Eosin channel.
dab_sigma_range (Union[tuple, None]): Adjustment range for the DAB channel.
"""

def check_sigma_range(name, given_range):
assert given_range is None or (
len(given_range) == 2
and given_range[0] < given_range[1]
and -1.0 <= given_range[0] <= 1.0
and -1.0 <= given_range[1] <= 1.0
), f"Invalid range for {name}: {given_range}"

check_sigma_range("Haematoxylin Sigma", haematoxylin_sigma_range)
check_sigma_range("Eosin Sigma", eosin_sigma_range)
check_sigma_range("Dab Sigma", dab_sigma_range)

self._sigma_ranges = [
haematoxylin_sigma_range,
eosin_sigma_range,
dab_sigma_range,
]
self._sigmas = [
haematoxylin_sigma_range[0]
if haematoxylin_sigma_range is not None
else 0.0,
eosin_sigma_range[0] if eosin_sigma_range is not None else 0.0,
dab_sigma_range[0] if dab_sigma_range is not None else 0.0,
]

def _setbiasranges(
self,
haematoxylin_bias_range: Union[tuple, None],
eosin_bias_range: Union[tuple, None],
dab_bias_range: Union[tuple, None],
):
"""
Set the bias intervals.

Args:
haematoxylin_bias_range (Union[tuple, None]): Bias range for the Haematoxylin channel.
eosin_bias_range (Union[tuple, None]): Bias range for the Eosin channel.
dab_bias_range (Union[tuple, None]): Bias range for the DAB channel.
"""

def check_bias_range(name, given_range):
assert given_range is None or (
len(given_range) != 2
or given_range[0] < given_range[1]
or -1.0 <= given_range[0]
or given_range[1] <= 1.0
), f"Invalid range for {name}: {given_range}"

check_bias_range("Haematoxylin Bias", haematoxylin_bias_range)
check_bias_range("Eosin Bias", eosin_bias_range)
check_bias_range("Dab Bias", dab_bias_range)

self._bias_ranges = [haematoxylin_bias_range, eosin_bias_range, dab_bias_range]
self._biases = [
haematoxylin_bias_range[0] if haematoxylin_bias_range is not None else 0.0,
eosin_bias_range[0] if eosin_bias_range is not None else 0.0,
dab_bias_range[0] if dab_bias_range is not None else 0.0,
]

def _setcutoffrange(self, cutoff_range: Union[tuple, None]):
"""
Set the cutoff value. Patches with mean value outside the cutoff interval will not be augmented.

Args:
cutoff_range (Union[tuple, None]): Cutoff range for mean value.
"""

def check_cutoff_range(name, given_range):
assert given_range is None or (
len(given_range) != 2
or given_range[0] < given_range[1]
or 0 <= given_range[0]
or given_range[1] <= 1.0
), f"Invalid range for {name}: {given_range}"

check_cutoff_range("Cutoff", cutoff_range)

self._cutoff_range = cutoff_range if cutoff_range is not None else [0.0, 1.0]

## commented the following lines because the user is never given access to this function
# def randomize(self):
# """Randomize the parameters of the augmenter."""

# # Randomize sigma and bias for each channel.
# self._sigmas = [
# np.random.uniform(sigma_range[0], sigma_range[1]) if sigma_range else 1.0
# for sigma_range in self._sigma_ranges
# ]
# self._biases = [
# np.random.uniform(bias_range[0], bias_range[1]) if bias_range else 0.0
# for bias_range in self._bias_ranges
# ]

def transform(self, patch: torch.Tensor) -> torch.Tensor:
"""
Apply color deformation on the patch.

Args:
patch (torch.Tensor): The input patch to transform.

Returns:
torch.Tensor: The transformed patch.
"""

current_patch = patch.numpy().astype(np.float32)
patch_mean = (
np.mean(current_patch) / 255.0
if current_patch.dtype.kind != "f"
else np.mean(current_patch)
)

if self._cutoff_range[0] <= patch_mean <= self._cutoff_range[1]:
# Convert the image patch to HED color coding.
patch_hed = rgb2hed(current_patch)

# Augment the channels.
for i in range(3):
if self._sigmas[i] != 0.0:
patch_hed[..., i] *= 1.0 + self._sigmas[i]
if self._biases[i] != 0.0:
patch_hed[..., i] += self._biases[i]

# Convert back to RGB color coding.
patch_transformed = hed2rgb(patch_hed)
patch_transformed = np.clip(patch_transformed, 0.0, 1.0)

# Convert back to integral data type if the input was also integral.
if current_patch.dtype.kind != "f":
patch_transformed *= 255.0
patch_transformed = patch_transformed.astype(np.uint8)

return patch_transformed

# The image patch is outside the cutoff interval.
return patch


class RandomHEDTransform(RandomTransform, IntensityTransform):
def __init__(
self,
haematoxylin_sigma_range: Union[float, Tuple[float, float]] = 0.1,
haematoxylin_bias_range: Union[float, Tuple[float, float]] = 0.1,
eosin_sigma_range: Union[float, Tuple[float, float]] = 0.1,
eosin_bias_range: Union[float, Tuple[float, float]] = 0.1,
dab_sigma_range: Union[float, Tuple[float, float]] = 0.1,
dab_bias_range: Union[float, Tuple[float, float]] = 0.1,
cutoff_range: Union[float, Tuple[float, float]] = (0, 1),
**kwargs,
):
super().__init__(**kwargs)
self.transform_object = HedColorAugmenter(
haematoxylin_sigma_range=haematoxylin_sigma_range,
haematoxylin_bias_range=haematoxylin_bias_range,
eosin_sigma_range=eosin_sigma_range,
eosin_bias_range=eosin_bias_range,
dab_sigma_range=dab_sigma_range,
dab_bias_range=dab_bias_range,
cutoff_range=cutoff_range,
)

def apply_transform(self, subject: Subject) -> Subject:
# Process only if the image is RGB
for _, image in self.get_images_dict(subject).items():
if image.data.shape[-1] == 1:
if image.data.ndim == 4:
tensor = image.data[..., 0]
# put channel to last axis (needed for colorconv to work)
tensor = tensor.permute(2, 1, 0)

# Apply transform
transformed_tensor = self.transform_object.transform(tensor)

# Convert tensor back to tensor data
transformed_data = (
torch.from_numpy(transformed_tensor)
.permute(2, 0, 1)
.unsqueeze(-1)
)

# Update image data
image.set_data(transformed_data)

return subject
1 change: 0 additions & 1 deletion GANDLF/data/augmentation/rgb_augs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from torchvision.transforms import ColorJitter
from typing import Tuple, Union

from torchio.transforms.augmentation import RandomTransform
from torchio.transforms import IntensityTransform
from torchio import Subject
Expand Down
Loading
Loading