diff --git a/src/sparcscore/processing/filtering.py b/src/sparcscore/processing/filtering.py index ff03b64f..b8d06059 100644 --- a/src/sparcscore/processing/filtering.py +++ b/src/sparcscore/processing/filtering.py @@ -418,6 +418,80 @@ def filter(self, input_mask): return self.update_mask(input_mask, self.ids_to_remove) class MatchNucleusCytosolIds(BaseFilter): + """ + Filter class for matching nucleusIDs to their matching cytosol IDs and removing all classes from the given + segmentation masks that do not fullfill the filtering criteria. + + Masks only pass filtering if both a nucleus and a cytosol mask are present and have an overlapping area + larger than the specified threshold. If the threshold is not specified, the default value is set to 0.5. + + Parameters + ---------- + filtering_threshold : float, optional + The threshold for filtering cytosol IDs based on the proportion of overlapping area with the nucleus. Default is 0.5. + downsampling_factor : int, optional + The downsampling factor for the masks. Default is None. + erosion_dilation : bool, optional + Flag indicating whether to perform erosion and dilation on the masks during upscaling. Default is True. + smoothing_kernel_size : int, optional + The size of the smoothing kernel for upscaling. Default is 7. + *args + Additional positional arguments. + **kwargs + Additional keyword arguments. + + Attributes + ---------- + filtering_threshold : float + The threshold for filtering cytosol IDs based on the proportion of overlapping area with the nucleus. + downsample : bool + Flag indicating whether downsampling is enabled. + downsampling_factor : int + The downsampling factor for the masks. + erosion_dilation : bool + Flag indicating whether to perform erosion and dilation on the masks during upscaling. + smoothing_kernel_size : int + The size of the smoothing kernel for upscaling. + nucleus_mask : numpy.ndarray + The nucleus mask. + cytosol_mask : numpy.ndarray + The cytosol mask. + nuclei_discard_list : list + A list of nucleus IDs to be discarded. + cytosol_discard_list : list + A list of cytosol IDs to be discarded. + nucleus_lookup_dict : dict + A dictionary mapping nucleus IDs to matched cytosol IDs after filtering. + + Methods + ------- + load_masks(nucleus_mask, cytosol_mask) + Load the nucleus and cytosol masks. + update_cytosol_mask(cytosol_mask) + Update the cytosol mask based on the matched nucleus-cytosol pairs. + update_masks() + Update the nucleus and cytosol masks after filtering. + match_nucleus_id(nucleus_id) + Match the given nucleus ID to a cytosol ID based on the overlapping area. + initialize_lookup_table() + Initialize the lookup table by matching all nucleus IDs to cytosol IDs. + count_cytosol_occurances() + Count the occurrences of each cytosol ID in the lookup table. + check_for_unassigned_cytosols() + Check for unassigned cytosol IDs in the cytosol mask. + identify_multinucleated_cells() + Identify and discard multinucleated cells from the lookup table. + cleanup_filtering_lists() + Cleanup the discard lists by removing duplicate entries. + cleanup_lookup_dictionary() + Cleanup the lookup dictionary by removing discarded nucleus-cytosol pairs. + generate_lookup_table(nucleus_mask, cytosol_mask) + Generate the lookup table by performing all necessary steps. + filter(nucleus_mask, cytosol_mask) + Filter the nucleus and cytosol masks based on the matching results. + + """ + def __init__( self, filtering_threshold=0.5, @@ -428,31 +502,46 @@ def __init__( **kwargs, ): super().__init__(*args, **kwargs) + + #set relevant parameters self.filtering_threshold = filtering_threshold - #set up downsampling + #set downsampling parameters if downsampling_factor is not None: - self.downsample = True + self.downsample = True self.downsampling_factor = downsampling_factor self.erosion_dilation = erosion_dilation self.smoothing_kernel_size = smoothing_kernel_size else: self.downsample = False - + + #initialize placeholders for masks self.nucleus_mask = None self.cytosol_mask = None + # initialize datastructures for saving results self._nucleus_lookup_dict = {} self.nuclei_discard_list = [] self.cytosol_discard_list = [] + self.nucleus_lookup_dict = None def load_masks(self, nucleus_mask, cytosol_mask): - #masks are only loaded once (if already loaded nothing happens) + """ + Load the nucleus and cytosol masks into their placeholders. + This function only loads the masks (and downsamples them if necessary) if this has not already been performed. + + Parameters + ---------- + nucleus_mask : numpy.ndarray + The nucleus mask. + cytosol_mask : numpy.ndarray + The cytosol mask. + """ if self.downsample: if self.nucleus_mask is None: self.nucleus_mask = self.downsample_mask(nucleus_mask) if self.cytosol_mask is None: - self.cytosol_mask = self.downsample_mask(cytosol_mask) + self.cytosol_mask = self.downsample_mask(cytosol_mask) else: if self.nucleus_mask is None: self.nucleus_mask = nucleus_mask @@ -460,22 +549,40 @@ def load_masks(self, nucleus_mask, cytosol_mask): self.cytosol_mask = cytosol_mask def update_cytosol_mask(self, cytosol_mask): - # now we have all the nucleus cytosol pairs we can filter the masks + """ + Update the cytosol mask based on the matched nucleus-cytosol pairs. + + Parameters + ---------- + cytosol_mask : numpy.ndarray + The cytosol mask. + + Returns + ------- + numpy.ndarray + The updated cytosol mask. + """ updated = np.zeros_like(cytosol_mask, dtype=bool) for nucleus_id, cytosol_id in self.nucleus_lookup_dict.items(): - # set the cytosol pixels to the nucleus_id if not previously updated - condition = np.logical_and( - cytosol_mask == cytosol_id, ~updated - ) - cytosol_mask[condition] = nucleus_id - updated = np.logical_or( - updated, condition - ) - return(cytosol_mask) + condition = np.logical_and( + cytosol_mask == cytosol_id, ~updated + ) + cytosol_mask[condition] = nucleus_id + updated = np.logical_or( + updated, condition + ) + return cytosol_mask def update_masks(self): + """ + Update the nucleus and cytosol masks after filtering. + Returns + ------- + tuple + A tuple containing the updated nucleus mask and cytosol mask. + """ nucleus_mask = self.update_mask(self.nucleus_mask, self.nuclei_discard_list) cytosol_mask = self.update_cytosol_mask(self.cytosol_mask) @@ -491,32 +598,23 @@ def match_nucleus_id(self, nucleus_id): Parameters ---------- - nucleus_mask : numpy.ndarray - The nucleus mask. - cytosol_mask : numpy.ndarray - The cytosol mask. nucleus_id : int The nucleus ID to be matched. Returns ------- - int - The matched cytosol ID. + int or None + The matched cytosol ID, or None if no match is found. """ - # get the coordinates of the nucleus nucleus_pixels = np.where(self.nucleus_mask == nucleus_id) - - # check if those indices are not background in the cytosol mask potential_cytosol = self.cytosol_mask[nucleus_pixels] - # if there is a cytosolID in the area of the nucleus proceed, else continue with a new nucleus if np.all(potential_cytosol != 0): unique_cytosol, counts = np.unique(potential_cytosol, return_counts=True) all_counts = np.sum(counts) cytosol_proportions = counts / all_counts if np.any(cytosol_proportions >= self.filtering_threshold): - # get the cytosol_id with max proportion cytosol_id = unique_cytosol[ np.argmax(cytosol_proportions >= self.filtering_threshold) ] @@ -534,21 +632,29 @@ def match_nucleus_id(self, nucleus_id): return None def initialize_lookup_table(self): + """ + Initialize the lookup table by matching all nucleus IDs to cytosol IDs. + """ all_nucleus_ids = self.get_unique_ids(self.nucleus_mask) for nucleus_id in all_nucleus_ids: self.match_nucleus_id(nucleus_id) def count_cytosol_occurances(self): + """ + Count the occurrences of each cytosol ID in the lookup table. + """ cytosol_count = defaultdict(int) - # Count the occurrences of each cytosol value for cytosol in self._nucleus_lookup_dict.values(): cytosol_count[cytosol] += 1 self.cytosol_count = cytosol_count def check_for_unassigned_cytosols(self): + """ + Check for unassigned cytosol IDs in the cytosol mask. + """ all_cytosol_ids = self.get_unique_ids(self.cytosol_mask) for cytosol_id in all_cytosol_ids: @@ -556,16 +662,25 @@ def check_for_unassigned_cytosols(self): self.cytosol_discard_list.append(cytosol_id) def identify_multinucleated_cells(self): + """ + Identify and discard multinucleated cells from the lookup table. + """ for nucleus, cytosol in self._nucleus_lookup_dict.items(): if self.cytosol_count[cytosol] > 1: self.nuclei_discard_list.append(nucleus) self.cytosol_discard_list.append(cytosol) def cleanup_filtering_lists(self): + """ + Cleanup the discard lists by removing duplicate entries. + """ self.nuclei_discard_list = list(set(self.nuclei_discard_list)) self.cytosol_discard_list = list(set(self.cytosol_discard_list)) def cleanup_lookup_dictionary(self): + """ + Cleanup the lookup dictionary by removing discarded nucleus-cytosol pairs. + """ _cleanup = [] for nucleus_id, cytosol_id in self._nucleus_lookup_dict.items(): if nucleus_id in self.nuclei_discard_list: @@ -577,6 +692,21 @@ def cleanup_lookup_dictionary(self): del self._nucleus_lookup_dict[nucleus] def generate_lookup_table(self, nucleus_mask, cytosol_mask): + """ + Generate the lookup table by performing all necessary steps. + + Parameters + ---------- + nucleus_mask : numpy.ndarray + The nucleus mask. + cytosol_mask : numpy.ndarray + The cytosol mask. + + Returns + ------- + dict + The lookup table mapping nucleus IDs to matched cytosol IDs. + """ self.load_masks(nucleus_mask, cytosol_mask) self.initialize_lookup_table() self.count_cytosol_occurances() @@ -585,13 +715,29 @@ def generate_lookup_table(self, nucleus_mask, cytosol_mask): self.cleanup_filtering_lists() self.cleanup_lookup_dictionary() - self.nucleus_lookup_dict = self._nucleus_lookup_dict #save final result to a new variable name + self.nucleus_lookup_dict = self._nucleus_lookup_dict - return(self.nucleus_lookup_dict) + return self.nucleus_lookup_dict def filter(self, nucleus_mask, cytosol_mask): - + """ + Filter the nucleus and cytosol masks based on the matching results and return the updated masks. + + Parameters + ---------- + nucleus_mask : numpy.ndarray + The nucleus mask. + cytosol_mask : numpy.ndarray + The cytosol mask. + + Returns + ------- + tuple + A tuple containing the updated nucleus mask and cytosol mask after filtering. + """ self.load_masks(nucleus_mask, cytosol_mask) - self.generate_lookup_table(nucleus_mask, cytosol_mask) + + if self.nucleus_lookup_dict is None: + self.generate_lookup_table(nucleus_mask, cytosol_mask) return self.update_masks() \ No newline at end of file