Skip to content

Commit

Permalink
add docstrings to filtering method
Browse files Browse the repository at this point in the history
  • Loading branch information
sophiamaedler committed Jun 3, 2024
1 parent 3da72f3 commit f23b9a0
Showing 1 changed file with 177 additions and 31 deletions.
208 changes: 177 additions & 31 deletions src/sparcscore/processing/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,80 @@ def filter(self, input_mask):
return self.update_mask(input_mask, self.ids_to_remove)

class MatchNucleusCytosolIds(BaseFilter):
"""
Filter class for matching nucleus IDs to their corresponding cytosol IDs and removing all classes from the given
segmentation masks that do not fulfill the filtering criteria.
Masks only pass filtering if both a nucleus and a cytosol mask are present and the proportion of the nucleus area
that overlaps with a single cytosol is at least the specified threshold. If the threshold is not specified, the default value is set to 0.5.
Parameters
----------
filtering_threshold : float, optional
The threshold for filtering cytosol IDs based on the proportion of overlapping area with the nucleus. Default is 0.5.
downsampling_factor : int, optional
The downsampling factor for the masks. Default is None.
erosion_dilation : bool, optional
Flag indicating whether to perform erosion and dilation on the masks during upscaling. Default is True.
smoothing_kernel_size : int, optional
The size of the smoothing kernel for upscaling. Default is 7.
*args
Additional positional arguments.
**kwargs
Additional keyword arguments.
Attributes
----------
filtering_threshold : float
The threshold for filtering cytosol IDs based on the proportion of overlapping area with the nucleus.
downsample : bool
Flag indicating whether downsampling is enabled.
downsampling_factor : int
The downsampling factor for the masks.
erosion_dilation : bool
Flag indicating whether to perform erosion and dilation on the masks during upscaling.
smoothing_kernel_size : int
The size of the smoothing kernel for upscaling.
nucleus_mask : numpy.ndarray
The nucleus mask.
cytosol_mask : numpy.ndarray
The cytosol mask.
nuclei_discard_list : list
A list of nucleus IDs to be discarded.
cytosol_discard_list : list
A list of cytosol IDs to be discarded.
nucleus_lookup_dict : dict
A dictionary mapping nucleus IDs to matched cytosol IDs after filtering.
Methods
-------
load_masks(nucleus_mask, cytosol_mask)
Load the nucleus and cytosol masks.
update_cytosol_mask(cytosol_mask)
Update the cytosol mask based on the matched nucleus-cytosol pairs.
update_masks()
Update the nucleus and cytosol masks after filtering.
match_nucleus_id(nucleus_id)
Match the given nucleus ID to a cytosol ID based on the overlapping area.
initialize_lookup_table()
Initialize the lookup table by matching all nucleus IDs to cytosol IDs.
count_cytosol_occurances()
Count the occurrences of each cytosol ID in the lookup table.
check_for_unassigned_cytosols()
Check for unassigned cytosol IDs in the cytosol mask.
identify_multinucleated_cells()
Identify and discard multinucleated cells from the lookup table.
cleanup_filtering_lists()
Cleanup the discard lists by removing duplicate entries.
cleanup_lookup_dictionary()
Cleanup the lookup dictionary by removing discarded nucleus-cytosol pairs.
generate_lookup_table(nucleus_mask, cytosol_mask)
Generate the lookup table by performing all necessary steps.
filter(nucleus_mask, cytosol_mask)
Filter the nucleus and cytosol masks based on the matching results.
"""

def __init__(
self,
filtering_threshold=0.5,
Expand All @@ -428,54 +502,87 @@ def __init__(
**kwargs,
):
super().__init__(*args, **kwargs)

#set relevant parameters
self.filtering_threshold = filtering_threshold

#set up downsampling
#set downsampling parameters
if downsampling_factor is not None:
self.downsample = True
self.downsample = True
self.downsampling_factor = downsampling_factor
self.erosion_dilation = erosion_dilation
self.smoothing_kernel_size = smoothing_kernel_size
else:
self.downsample = False


#initialize placeholders for masks
self.nucleus_mask = None
self.cytosol_mask = None

# initialize datastructures for saving results
self._nucleus_lookup_dict = {}
self.nuclei_discard_list = []
self.cytosol_discard_list = []
self.nucleus_lookup_dict = None

def load_masks(self, nucleus_mask, cytosol_mask):
    """
    Store the nucleus and cytosol masks on the instance, at most once each.

    A placeholder that is already populated is left untouched, so repeated
    calls neither reload nor re-downsample a mask. When ``self.downsample``
    is set, a mask is passed through ``self.downsample_mask`` before being
    stored; otherwise the array is stored as given (no copy is made).

    Parameters
    ----------
    nucleus_mask : numpy.ndarray
        The nucleus mask.
    cytosol_mask : numpy.ndarray
        The cytosol mask.
    """
    # Each mask is guarded independently so that a partially loaded state
    # (only one placeholder filled) is completed rather than overwritten.
    if self.nucleus_mask is None:
        self.nucleus_mask = (
            self.downsample_mask(nucleus_mask) if self.downsample else nucleus_mask
        )
    if self.cytosol_mask is None:
        self.cytosol_mask = (
            self.downsample_mask(cytosol_mask) if self.downsample else cytosol_mask
        )

def update_cytosol_mask(self, cytosol_mask):
    """
    Relabel matched cytosols with the ID of their nucleus.

    For every ``nucleus_id -> cytosol_id`` pair in ``self.nucleus_lookup_dict``
    the pixels currently equal to ``cytosol_id`` are overwritten with
    ``nucleus_id``. The array is modified in place and also returned.
    Pixels whose ID does not appear as a value in the lookup dictionary are
    not touched by this method.

    Parameters
    ----------
    cytosol_mask : numpy.ndarray
        The cytosol mask.

    Returns
    -------
    numpy.ndarray
        The updated cytosol mask (the same array object that was passed in).
    """
    # Tracks pixels that have already been relabelled. Without it, a freshly
    # written nucleus_id that happens to equal a later cytosol_id would be
    # relabelled a second time.
    updated = np.zeros_like(cytosol_mask, dtype=bool)

    for nucleus_id, cytosol_id in self.nucleus_lookup_dict.items():
        condition = np.logical_and(cytosol_mask == cytosol_id, ~updated)
        cytosol_mask[condition] = nucleus_id
        updated = np.logical_or(updated, condition)

    return cytosol_mask

def update_masks(self):
"""
Update the nucleus and cytosol masks after filtering.
Returns
-------
tuple
A tuple containing the updated nucleus mask and cytosol mask.
"""
nucleus_mask = self.update_mask(self.nucleus_mask, self.nuclei_discard_list)
cytosol_mask = self.update_cytosol_mask(self.cytosol_mask)

Expand All @@ -491,32 +598,23 @@ def match_nucleus_id(self, nucleus_id):
Parameters
----------
nucleus_mask : numpy.ndarray
The nucleus mask.
cytosol_mask : numpy.ndarray
The cytosol mask.
nucleus_id : int
The nucleus ID to be matched.
Returns
-------
int
The matched cytosol ID.
int or None
The matched cytosol ID, or None if no match is found.
"""
# get the coordinates of the nucleus
nucleus_pixels = np.where(self.nucleus_mask == nucleus_id)

# check if those indices are not background in the cytosol mask
potential_cytosol = self.cytosol_mask[nucleus_pixels]

# if there is a cytosolID in the area of the nucleus proceed, else continue with a new nucleus
if np.all(potential_cytosol != 0):
unique_cytosol, counts = np.unique(potential_cytosol, return_counts=True)
all_counts = np.sum(counts)
cytosol_proportions = counts / all_counts

if np.any(cytosol_proportions >= self.filtering_threshold):
# get the cytosol_id with max proportion
cytosol_id = unique_cytosol[
np.argmax(cytosol_proportions >= self.filtering_threshold)
]
Expand All @@ -534,38 +632,55 @@ def match_nucleus_id(self, nucleus_id):
return None

def initialize_lookup_table(self):
    """
    Run the nucleus-to-cytosol matching for every ID in the nucleus mask.

    The IDs are obtained from ``self.get_unique_ids(self.nucleus_mask)`` and
    each one is handed to ``self.match_nucleus_id``. The value returned by
    ``match_nucleus_id`` is ignored here.
    """
    # NOTE(review): presumably match_nucleus_id records its result on the
    # instance (lookup dict / discard lists) - confirm against its body.
    for current_nucleus in self.get_unique_ids(self.nucleus_mask):
        self.match_nucleus_id(current_nucleus)

def count_cytosol_occurances(self):
    """
    Tally how many nuclei are mapped to each cytosol ID.

    Reads the values of ``self._nucleus_lookup_dict`` and stores the result
    in ``self.cytosol_count`` as a ``defaultdict(int)`` keyed by cytosol ID.
    Nothing is returned.
    """
    occurrences = defaultdict(int)

    # One increment per nucleus that points at the given cytosol.
    for matched_cytosol in self._nucleus_lookup_dict.values():
        occurrences[matched_cytosol] = occurrences[matched_cytosol] + 1

    self.cytosol_count = occurrences

def check_for_unassigned_cytosols(self):
    """
    Mark every cytosol ID that no nucleus was matched to for removal.

    All IDs returned by ``self.get_unique_ids(self.cytosol_mask)`` that do not
    occur as a value in ``self._nucleus_lookup_dict`` are appended to
    ``self.cytosol_discard_list``. Nothing is returned.
    """
    # Build the set of matched cytosols once instead of scanning the dict
    # values for every cytosol ID. The values are hashable: they are already
    # used as dictionary keys in count_cytosol_occurances.
    assigned_cytosols = set(self._nucleus_lookup_dict.values())

    self.cytosol_discard_list.extend(
        cytosol_id
        for cytosol_id in self.get_unique_ids(self.cytosol_mask)
        if cytosol_id not in assigned_cytosols
    )

def identify_multinucleated_cells(self):
    """
    Flag nucleus/cytosol pairs whose cytosol is shared by several nuclei.

    For each entry of ``self._nucleus_lookup_dict`` whose cytosol has a count
    greater than one in ``self.cytosol_count``, the nucleus ID is appended to
    ``self.nuclei_discard_list`` and the cytosol ID to
    ``self.cytosol_discard_list``. The cytosol ID is appended once per
    matching nucleus, so the list may contain duplicates afterwards. The
    lookup dictionary itself is not modified here.
    """
    for nucleus_id, cytosol_id in self._nucleus_lookup_dict.items():
        if self.cytosol_count[cytosol_id] <= 1:
            # Cytosol belongs to exactly one nucleus: keep the pair.
            continue
        self.nuclei_discard_list.append(nucleus_id)
        self.cytosol_discard_list.append(cytosol_id)

def cleanup_filtering_lists(self):
    """
    Remove duplicate IDs from both discard lists.

    ``self.nuclei_discard_list`` and ``self.cytosol_discard_list`` are each
    rebuilt from a set of their current entries. Because a set is used, the
    original ordering of the entries is not guaranteed to be kept.
    """
    unique_nuclei = set(self.nuclei_discard_list)
    unique_cytosols = set(self.cytosol_discard_list)

    self.nuclei_discard_list = list(unique_nuclei)
    self.cytosol_discard_list = list(unique_cytosols)

def cleanup_lookup_dictionary(self):
"""
Cleanup the lookup dictionary by removing discarded nucleus-cytosol pairs.
"""
_cleanup = []
for nucleus_id, cytosol_id in self._nucleus_lookup_dict.items():
if nucleus_id in self.nuclei_discard_list:
Expand All @@ -577,6 +692,21 @@ def cleanup_lookup_dictionary(self):
del self._nucleus_lookup_dict[nucleus]

def generate_lookup_table(self, nucleus_mask, cytosol_mask):
"""
Generate the lookup table by performing all necessary steps.
Parameters
----------
nucleus_mask : numpy.ndarray
The nucleus mask.
cytosol_mask : numpy.ndarray
The cytosol mask.
Returns
-------
dict
The lookup table mapping nucleus IDs to matched cytosol IDs.
"""
self.load_masks(nucleus_mask, cytosol_mask)
self.initialize_lookup_table()
self.count_cytosol_occurances()
Expand All @@ -585,13 +715,29 @@ def generate_lookup_table(self, nucleus_mask, cytosol_mask):
self.cleanup_filtering_lists()
self.cleanup_lookup_dictionary()

self.nucleus_lookup_dict = self._nucleus_lookup_dict #save final result to a new variable name
self.nucleus_lookup_dict = self._nucleus_lookup_dict

return(self.nucleus_lookup_dict)
return self.nucleus_lookup_dict

def filter(self, nucleus_mask, cytosol_mask):
    """
    Filter the nucleus and cytosol masks based on the matching results.

    The masks are loaded via ``self.load_masks``; the lookup table is
    generated only if ``self.nucleus_lookup_dict`` is still ``None``, so a
    table produced by an earlier call to ``generate_lookup_table`` is reused
    instead of being rebuilt. The result of ``self.update_masks()`` is
    returned.

    Parameters
    ----------
    nucleus_mask : numpy.ndarray
        The nucleus mask.
    cytosol_mask : numpy.ndarray
        The cytosol mask.

    Returns
    -------
    tuple
        A tuple containing the updated nucleus mask and cytosol mask after filtering.
    """
    self.load_masks(nucleus_mask, cytosol_mask)

    # Only build the lookup table once; regenerating it would re-run the
    # matching and append to the discard lists a second time.
    if self.nucleus_lookup_dict is None:
        self.generate_lookup_table(nucleus_mask, cytosol_mask)

    return self.update_masks()

0 comments on commit f23b9a0

Please sign in to comment.