From 0b62d8ae6f129c35bbc69c31a8aad0c01e9bdc72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Wed, 5 Jun 2024 12:42:30 +0200 Subject: [PATCH] rename functions more consistently --- src/sparcscore/processing/filtering.py | 164 +++++++++++++++---------- 1 file changed, 101 insertions(+), 63 deletions(-) diff --git a/src/sparcscore/processing/filtering.py b/src/sparcscore/processing/filtering.py index c6e31a6b..aa6603ff 100644 --- a/src/sparcscore/processing/filtering.py +++ b/src/sparcscore/processing/filtering.py @@ -14,11 +14,11 @@ class BaseFilter(Logable): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - + def get_unique_ids(self, mask): return np.unique(mask)[1:] # to remove the background - def update_mask(self, mask, ids_to_remove): + def get_updated_mask(self, mask, ids_to_remove): """ Update the given mask by setting the values corresponding to the specified IDs to 0. @@ -36,10 +36,10 @@ def update_mask(self, mask, ids_to_remove): """ return np.where(np.isin(mask, ids_to_remove), 0, mask) - def downsample_mask(self, mask): + def get_downsampled_mask(self, mask): return downsample_img_pxs(mask, N=self.downsampling_factor) - def upscale_mask_basic(self, mask, erosion_dilation=False): + def get_upscaled_mask_basic(self, mask, erosion_dilation=False): mask = mask.repeat(self.downsampling_factor, axis=0).repeat( self.downsampling_factor, axis=1 ) @@ -50,6 +50,7 @@ def upscale_mask_basic(self, mask, erosion_dilation=False): return mask + class SizeFilter(BaseFilter): """ Filter class for removing objects from a mask based on their size. @@ -80,6 +81,15 @@ class SizeFilter(BaseFilter): The directory to save the generated figures. If not provided, the current working directory will be used. confidence_interval : float, optional The confidence interval for calculating the filtering threshold. Default is 0.95. + n_components : int, optional + The number of components in the Gaussian mixture model. Default is 1. + population_to_keep : str, optional + For multipopulation models this parameter determines which population should be kept. Options are "largest", "smallest", "mostcommon", "leastcommon". Default is "mostcommon". + If set to "largest" or "smallest", the model is chosen which has the largest or smallest mean. If set to "mostcommon" or "leastcommon", the model is chosen whose population is least or most common. + *args + Additional positional arguments. + **kwargs + Additional keyword arguments. Examples -------- @@ -102,7 +112,7 @@ def __init__( plot_qc=True, directory=None, confidence_interval=0.95, - n_components=2, + n_components=1, population_to_keep="largest", *args, **kwargs, @@ -125,7 +135,7 @@ def __init__( self.ids_to_remove = None - def plot_gaussian_model( + def _get_gaussian_model_plot( self, counts, means, @@ -213,7 +223,7 @@ def plot_gaussian_model( return fig - def plot_histogram( + def _get_histogram_plot( self, values, label=None, @@ -259,7 +269,40 @@ def plot_histogram( return fig - def calculate_filtering_threshold(self, counts): + def _get_index_population(self, means, weights): + """ + Returns the index of the model that matches the population of cells to be kept. + + Parameters + ---------- + means : numpy.ndarray + An array containing the means of the models. + weights : numpy.ndarray + An array containing the weights of the models. + + Returns + ------- + int + The index of the model that matches the population criteria. + + Notes + ----- + The function determines the index of the model based on the population of cells that should be kept. + The population criteria can be set to "largest", "smallest", "mostcommon", or "leastcommon". + If the population criteria is set to "largest" or "smallest", the index is determined based on the means array. + If the population criteria is set to "mostcommon" or "leastcommon", the index is determined based on the weights array. + """ + if self.population_to_keep == "largest": + idx = np.argmax(means) + elif self.population_to_keep == "smallest": + idx = np.argmin(means) + elif self.population_to_keep == "mostcommon": + idx = np.argmax(weights) + elif self.population_to_keep == "leastcommon": + idx = np.argmin(weights) + return idx + + def _calculate_filtering_threshold(self, counts): """ Calculate the filtering thresholds for the given counts. @@ -306,10 +349,7 @@ def calculate_filtering_threshold(self, counts): weights = gmm.weights_ # get index of the model which matches to the population of cells that should be kept - if self.population_to_keep == "largest": - idx = np.argmax(means) - elif self.population_to_keep == "smallest": - idx = np.argmin(means) + idx = self._get_index_population(means, weights) # calculate the thresholds for the selected model using the given confidence interval mu = means[idx] @@ -324,7 +364,7 @@ def calculate_filtering_threshold(self, counts): threshold = (lower_threshold, upper_threshold) - fig = self.plot_gaussian_model( + fig = self._get_gaussian_model_plot( counts=data, means=means, variances=variances, @@ -343,9 +383,8 @@ def calculate_filtering_threshold(self, counts): self.log( f"Calculated threshold for {self.label} with {self.confidence_interval * 100}% confidence interval: {self.threshold}" ) - return self.threshold - def get_ids_to_remove(self, input_mask): + def _get_ids_to_remove(self, input_mask): """ Get the IDs to remove from the input mask based on the filtering threshold. @@ -371,7 +410,7 @@ def get_ids_to_remove(self, input_mask): counts = np.unique(input_mask, return_counts=True) pixel_counts = counts[1][1:] - fig = self.plot_histogram(pixel_counts, self.label) + fig = self._get_histogram_plot(pixel_counts, self.label) plt.close(fig) if self.plot_qc: @@ -379,7 +418,7 @@ def get_ids_to_remove(self, input_mask): # automatically calculate filtering threshold if not provided if self.filter_threshold is None: - self.filter_threshold = self.calculate_filtering_threshold(pixel_counts) + self._calculate_filtering_threshold(pixel_counts) ids_remove = [] _ids = counts[0][1:][np.where(pixel_counts < self.filter_threshold[0])] @@ -414,9 +453,10 @@ def filter(self, input_mask): """ if self.ids_to_remove is None: - self.get_ids_to_remove(input_mask) + self._get_ids_to_remove(input_mask) return self.update_mask(input_mask, self.ids_to_remove) + class MatchNucleusCytosolIds(BaseFilter): """ Filter class for matching nucleusIDs to their matching cytosol IDs and removing all classes from the given @@ -503,29 +543,29 @@ def __init__( ): super().__init__(*args, **kwargs) - #set relevant parameters + # set relevant parameters self.filtering_threshold = filtering_threshold - #set downsampling parameters + # set downsampling parameters if downsampling_factor is not None: - self.downsample = True + self.downsample = True self.downsampling_factor = downsampling_factor self.erosion_dilation = erosion_dilation self.smoothing_kernel_size = smoothing_kernel_size else: self.downsample = False - - #initialize placeholders for masks + + # initialize placeholders for masks self.nucleus_mask = None self.cytosol_mask = None - + # initialize datastructures for saving results self._nucleus_lookup_dict = {} self.nuclei_discard_list = [] self.cytosol_discard_list = [] self.nucleus_lookup_dict = None - def load_masks(self, nucleus_mask, cytosol_mask): + def _load_masks(self, nucleus_mask, cytosol_mask): """ Load the nucleus and cytosol masks into their placeholders. This function only loads the masks (and downsamples them if necessary) if this has not already been performed. @@ -547,8 +587,8 @@ def load_masks(self, nucleus_mask, cytosol_mask): self.nucleus_mask = nucleus_mask if self.cytosol_mask is None: self.cytosol_mask = cytosol_mask - - def update_cytosol_mask(self, cytosol_mask): + + def _get_updated_cytosol_mask(self, cytosol_mask): """ Update the cytosol mask based on the matched nucleus-cytosol pairs. @@ -565,16 +605,12 @@ def update_cytosol_mask(self, cytosol_mask): updated = np.zeros_like(cytosol_mask, dtype=bool) for nucleus_id, cytosol_id in self.nucleus_lookup_dict.items(): - condition = np.logical_and( - cytosol_mask == cytosol_id, ~updated - ) + condition = np.logical_and(cytosol_mask == cytosol_id, ~updated) cytosol_mask[condition] = nucleus_id - updated = np.logical_or( - updated, condition - ) + updated = np.logical_or(updated, condition) return cytosol_mask - - def update_masks(self): + + def _get_updated_masks(self): """ Update the nucleus and cytosol masks after filtering. @@ -588,11 +624,13 @@ def update_masks(self): if self.downsample: nucleus_mask = self.upscale_mask_basic(nucleus_mask, self.erosion_dilation) - cytosol_mask = self.upscale_mask_basic(self.cytosol_mask, self.erosion_dilation) - + cytosol_mask = self.upscale_mask_basic( + self.cytosol_mask, self.erosion_dilation + ) + return nucleus_mask, cytosol_mask - - def match_nucleus_id(self, nucleus_id): + + def _match_nucleus_id(self, nucleus_id): """ Match the given nucleus ID to a cytosol ID based on the overlapping area. @@ -630,8 +668,8 @@ def match_nucleus_id(self, nucleus_id): else: self.nuclei_discard_list.append(nucleus_id) return None - - def initialize_lookup_table(self): + + def _initialize_lookup_table(self): """ Initialize the lookup table by matching all nucleus IDs to cytosol IDs. """ @@ -639,8 +677,8 @@ def initialize_lookup_table(self): for nucleus_id in all_nucleus_ids: self.match_nucleus_id(nucleus_id) - - def count_cytosol_occurances(self): + + def _count_cytosol_occurances(self): """ Count the occurrences of each cytosol ID in the lookup table. """ @@ -648,10 +686,10 @@ def count_cytosol_occurances(self): for cytosol in self._nucleus_lookup_dict.values(): cytosol_count[cytosol] += 1 - + self.cytosol_count = cytosol_count - - def check_for_unassigned_cytosols(self): + + def _check_for_unassigned_cytosols(self): """ Check for unassigned cytosol IDs in the cytosol mask. """ @@ -660,8 +698,8 @@ def check_for_unassigned_cytosols(self): for cytosol_id in all_cytosol_ids: if cytosol_id not in self._nucleus_lookup_dict.values(): self.cytosol_discard_list.append(cytosol_id) - - def identify_multinucleated_cells(self): + + def _identify_multinucleated_cells(self): """ Identify and discard multinucleated cells from the lookup table. """ @@ -670,14 +708,14 @@ def identify_multinucleated_cells(self): self.nuclei_discard_list.append(nucleus) self.cytosol_discard_list.append(cytosol) - def cleanup_filtering_lists(self): + def _cleanup_filtering_lists(self): """ Cleanup the discard lists by removing duplicate entries. """ self.nuclei_discard_list = list(set(self.nuclei_discard_list)) self.cytosol_discard_list = list(set(self.cytosol_discard_list)) - def cleanup_lookup_dictionary(self): + def _cleanup_lookup_dictionary(self): """ Cleanup the lookup dictionary by removing discarded nucleus-cytosol pairs. """ @@ -687,13 +725,13 @@ def cleanup_lookup_dictionary(self): _cleanup.append(nucleus_id) if cytosol_id in self.cytosol_discard_list: _cleanup.append(nucleus_id) - - #ensure we have no duplicate entries + + # ensure we have no duplicate entries _cleanup = list(set(_cleanup)) for nucleus in _cleanup: del self._nucleus_lookup_dict[nucleus] - def generate_lookup_table(self, nucleus_mask, cytosol_mask): + def get_lookup_table(self, nucleus_mask, cytosol_mask): """ Generate the lookup table by performing all necessary steps. @@ -710,12 +748,12 @@ def generate_lookup_table(self, nucleus_mask, cytosol_mask): The lookup table mapping nucleus IDs to matched cytosol IDs. """ self.load_masks(nucleus_mask, cytosol_mask) - self.initialize_lookup_table() - self.count_cytosol_occurances() - self.check_for_unassigned_cytosols() - self.identify_multinucleated_cells() - self.cleanup_filtering_lists() - self.cleanup_lookup_dictionary() + self._initialize_lookup_table() + self._count_cytosol_occurances() + self._check_for_unassigned_cytosols() + self._identify_multinucleated_cells() + self._cleanup_filtering_lists() + self._cleanup_lookup_dictionary() self.nucleus_lookup_dict = self._nucleus_lookup_dict @@ -737,9 +775,9 @@ def filter(self, nucleus_mask, cytosol_mask): tuple A tuple containing the updated nucleus mask and cytosol mask after filtering. """ - self.load_masks(nucleus_mask, cytosol_mask) + self._load_masks(nucleus_mask, cytosol_mask) if self.nucleus_lookup_dict is None: - self.generate_lookup_table(nucleus_mask, cytosol_mask) - - return self.update_masks() \ No newline at end of file + self.get_lookup_table(nucleus_mask, cytosol_mask) + + return self._get_updated_masks()