From 69fc10473b8517b4d711a63b284168101e73f7cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Mon, 3 Jun 2024 16:44:45 +0200 Subject: [PATCH 01/18] add filter to match cytosol and nucleus ids --- src/sparcscore/processing/filtering.py | 299 +++++++++++++++++++++---- 1 file changed, 254 insertions(+), 45 deletions(-) diff --git a/src/sparcscore/processing/filtering.py b/src/sparcscore/processing/filtering.py index 0450c189..9add1cee 100644 --- a/src/sparcscore/processing/filtering.py +++ b/src/sparcscore/processing/filtering.py @@ -4,22 +4,66 @@ from scipy.stats import norm import os +from skimage.morphology import disk, dilation, erosion +from collections import defaultdict + from sparcscore.pipeline.base import Logable +from sparcscore.processing.preprocessing import downsample_img_pxs + + +class BaseFilter(Logable): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def get_unique_ids(self, mask): + return np.unique(mask)[1:] # to remove the background + + def update_mask(self, mask, ids_to_remove): + """ + Update the given mask by setting the values corresponding to the specified IDs to 0. + + Parameters + ---------- + mask : numpy.ndarray + The mask to be updated. + ids_to_remove : numpy.ndarray + The IDs to be removed from the mask. + + Returns + ------- + numpy.ndarray + The updated mask with the specified IDs set to 0. + """ + return np.where(np.isin(mask, ids_to_remove), 0, mask) + + def downsample_mask(self, mask): + return downsample_img_pxs(mask, N=self.downsampling_factor) + + def upscale_mask_basic(self, mask, erosion_dilation=False): + mask = mask.repeat(self.downsampling_factor, axis=0).repeat( + self.downsampling_factor, axis=1 + ) -class SizeFilter(Logable): + if erosion_dilation: + mask = erosion(mask, footprint=disk(self.smoothing_kernel_size)) + mask = dilation(mask, footprint=disk(self.smoothing_kernel_size)) + + return mask + +class SizeFilter(BaseFilter): """ Filter class for removing objects from a mask based on their size. - This class provides methods to remove objects from a segmentation mask based on their size. + This class provides methods to remove objects from a segmentation mask based on their size. If specified the objects are filtered using a threshold range passed by the user. Otherwise, - this threshold range will be automatically calculated. - + this threshold range will be automatically calculated. + To automatically calculate the threshold range, a gaussian mixture model will be fitted to the data. Per default, the number of components is set to 2, as it is assumed that the objects in the mask can be divided into two groups: small and large objects. The small objects constitute segmentation artefacts (partial masks that are frequently generated by segmentation models like e.g. cellpose) while the large objects represent the actual cell masks of interest. Using the fitted model, the filtering thresholds are calculated - to remove all cells that fall outside of the given confidence interval. + to remove all cells that fall outside of the given confidence interval. Parameters ---------- @@ -28,7 +72,7 @@ class SizeFilter(Logable): label : str, optional The label of the mask. Default is "segmask". log : bool, optional - Whether to take the logarithm of the size of the objects before fitting the normal distribution. Default is True. + Whether to take the logarithm of the size of the objects before fitting the normal distribution. Default is True. By enabling this option, the filter will better be able to distinguish between small and large objects. plot_qc : bool, optional Whether to plot quality control figures. Default is True. @@ -47,9 +91,9 @@ class SizeFilter(Logable): >>> ids_to_remove = filter.get_ids_to_remove(input_mask) >>> # Update the mask by removing the identified object IDs >>> updated_mask = filter.update_mask(input_mask, ids_to_remove) - + """ - + def __init__( self, filter_threshold=None, @@ -64,7 +108,7 @@ def __init__( **kwargs, ): super().__init__(*args, **kwargs) - + self.log_values = log self.plot_qc = plot_qc self.label = label @@ -72,8 +116,8 @@ def __init__( self.confidence_interval = confidence_interval self.n_components = n_components self.population_to_keep = population_to_keep - - #if no directory is provided, use the current working directory + + # if no directory is provided, use the current working directory if directory is not None: self.directory = directory else: @@ -81,7 +125,8 @@ def __init__( self.ids_to_remove = None - def plot_gaussian_model(self, + def plot_gaussian_model( + self, counts, means, variances, @@ -90,7 +135,7 @@ def plot_gaussian_model(self, bins=30, figsize=(5, 5), alpha=0.5, - save_figure=True + save_figure=True, ): """ Plot a histogram of the provided data with fitted Gaussian distributions. @@ -167,7 +212,7 @@ def plot_gaussian_model(self, fig.savefig(os.path.join(self.directory, f"{self.label}_bimodal_model.png")) return fig - + def plot_histogram( self, values, @@ -266,12 +311,12 @@ def calculate_filtering_threshold(self, counts): elif self.population_to_keep == "smallest": idx = np.argmin(means) - #calculate the thresholds for the selected model using the given confidence interval - mu = means[idx] - sigma = np.sqrt(variances[idx]) + # calculate the thresholds for the selected model using the given confidence interval + mu = means[idx] + sigma = np.sqrt(variances[idx]) - percent = (1 - self.confidence_interval) - lower = percent / 2 + percent = 1 - self.confidence_interval + lower = percent / 2 upper = 1 - percent / 2 lower_threshold = mu + sigma * norm.ppf(lower) @@ -279,11 +324,12 @@ def calculate_filtering_threshold(self, counts): threshold = (lower_threshold, upper_threshold) - fig = self.plot_gaussian_model(counts = data, - means = means, - variances = variances, - weights = weights, - threshold = threshold, + fig = self.plot_gaussian_model( + counts=data, + means=means, + variances=variances, + weights=weights, + threshold=threshold, ) if self.plot_qc: @@ -293,8 +339,10 @@ def calculate_filtering_threshold(self, counts): self.threshold = np.exp(threshold) else: self.threshold = threshold - - self.log(f"Calculated threshold for {self.label} with {self.confidence_interval * 100}% confidence interval: {self.threshold}") + + self.log( + f"Calculated threshold for {self.label} with {self.confidence_interval * 100}% confidence interval: {self.threshold}" + ) return self.threshold def get_ids_to_remove(self, input_mask): @@ -350,24 +398,6 @@ def get_ids_to_remove(self, input_mask): self.ids_to_remove = ids_remove - def update_mask(self, mask, ids_to_remove): - """ - Update the given mask by setting the values corresponding to the specified IDs to 0. - - Parameters - ---------- - mask : numpy.ndarray - The mask to be updated. - ids_to_remove : numpy.ndarray - The IDs to be removed from the mask. - - Returns - ------- - numpy.ndarray - The updated mask with the specified IDs set to 0. - """ - return np.where(np.isin(mask, ids_to_remove), 0, mask) - def filter(self, input_mask): """ Filter the input mask based on the filtering threshold. @@ -375,7 +405,7 @@ def filter(self, input_mask): Parameters ---------- input_mask : ndarray - The input mask to be filtered. + The input mask to be filtered. Expected shape is (X, Y) Returns ------- @@ -386,3 +416,182 @@ def filter(self, input_mask): if self.ids_to_remove is None: self.get_ids_to_remove(input_mask) return self.update_mask(input_mask, self.ids_to_remove) + +class MatchNucleusCytosolIds(BaseFilter): + def __init__( + self, + filtering_threshold=0.5, + downsampling_factor=None, + erosion_dilation=True, + smoothing_kernel_size=7, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.filtering_threshold = filtering_threshold + + #set up downsampling + if downsampling_factor is not None: + self.downsample = True + self.downsampling_factor = downsampling_factor + self.erosion_dilation = erosion_dilation + self.smoothing_kernel_size = smoothing_kernel_size + else: + self.downsample = False + + self.nucleus_mask = None + self.cytosol_mask = None + + self._nucleus_lookup_dict = {} + self.nuclei_discard_list = [] + self.cytosol_discard_list = [] + + def load_masks(self, nucleus_mask, cytosol_mask): + #masks are only loaded once (if already loaded nothing happens) + if self.downsample: + if self.nucleus_mask is None: + self.nucleus_mask = self.downsample_mask(nucleus_mask) + if self.cytosol_mask is None: + self.cytosol_mask = self.downsample_mask(cytosol_mask) + else: + if self.nucleus_mask is None: + self.nucleus_mask = nucleus_mask + if self.cytosol_mask is None: + self.cytosol_mask = cytosol_mask + + def update_cytosol_mask(self, cytosol_mask): + # now we have all the nucleus cytosol pairs we can filter the masks + updated = np.zeros_like(cytosol_mask, dtype=bool) + + for nucleus_id, cytosol_id in self.nucleus_lookup_dict.items(): + # set the cytosol pixels to the nucleus_id if not previously updated + condition = np.logical_and( + cytosol_mask == cytosol_id, ~updated + ) + cytosol_mask[condition] = nucleus_id + updated = np.logical_or( + updated, condition + ) + return(cytosol_mask) + + def update_masks(self): + + nucleus_mask = self.update_mask(self.nucleus_mask, self.nuclei_discard_list) + cytosol_mask = self.update_mask(self.cytosol_mask) + + if self.downsample: + nucleus_mask = self.upscale_mask_basic(nucleus_mask, self.erosion_dilation) + cytosol_mask = self.upscale_mask_basic(self.cytosol_mask, self.erosion_dilation) + + return nucleus_mask, cytosol_mask + + def match_nucleus_id(self, nucleus_id): + """ + Match the given nucleus ID to a cytosol ID based on the overlapping area. + + Parameters + ---------- + nucleus_mask : numpy.ndarray + The nucleus mask. + cytosol_mask : numpy.ndarray + The cytosol mask. + nucleus_id : int + The nucleus ID to be matched. + + Returns + ------- + int + The matched cytosol ID. + """ + # get the coordinates of the nucleus + nucleus_pixels = np.where(self.nucleus_mask == nucleus_id) + + # check if those indices are not background in the cytosol mask + potential_cytosol = self.cytosol_mask[nucleus_pixels] + + # if there is a cytosolID in the area of the nucleus proceed, else continue with a new nucleus + if np.all(potential_cytosol != 0): + unique_cytosol, counts = np.unique(potential_cytosol, return_counts=True) + all_counts = np.sum(counts) + cytosol_proportions = counts / all_counts + + if np.any(cytosol_proportions >= self.filtering_threshold): + # get the cytosol_id with max proportion + cytosol_id = unique_cytosol[ + np.argmax(cytosol_proportions >= self.filtering_threshold) + ] + if cytosol_id != 0: + self._nucleus_lookup_dict[nucleus_id] = cytosol_id + return cytosol_id + else: + self.nuclei_discard_list.append(nucleus_id) + return None + else: + self.nuclei_discard_list.append(nucleus_id) + return None + else: + self.nuclei_discard_list.append(nucleus_id) + return None + + def initialize_lookup_table(self): + all_nucleus_ids = self.get_unique_ids(self.nucleus_mask) + + for nucleus_id in all_nucleus_ids: + self.match_nucleus_id(nucleus_id) + + def count_cytosol_occurances(self): + cytosol_count = defaultdict(int) + + # Count the occurrences of each cytosol value + for cytosol in self._nucleus_lookup_dict.values(): + cytosol_count[cytosol] += 1 + + self.cytosol_counts = cytosol_count + + def check_for_unassigned_cytosols(self): + all_cytosol_ids = self.get_unique_ids(self.cytosol_mask) + + for cytosol_id in all_cytosol_ids: + if cytosol_id not in self._nucleus_lookup_dict.values(): + self.cytosol_discard_list.append(cytosol_id) + + def identify_multinucleated_cells(self): + for nucleus, cytosol in self._nucleus_lookup_dict.items(): + if self.cytosol_count[cytosol] > 1: + self.nuclei_discard_list.append(nucleus) + self.cytosol_discard_list.append(cytosol) + + def cleanup_filtering_lists(self): + self.nuclei_discard_list = list(set(self.nuclei_discard_list)) + self.cytosol_discard_list = list(set(self.cytosol_discard_list)) + + def cleanup_lookup_dictionary(self): + _cleanup = [] + for nucleus_id, cytosol_id in self._nucleus_lookup_dict.items(): + if nucleus_id in self.nuclei_discard_list: + _cleanup.append(nucleus_id) + if cytosol_id in self.cytosol_discard_list: + _cleanup.append(nucleus_id) + + for nucleus in _cleanup: + del self._nucleus_lookup_dict[nucleus] + + def generate_lookup_table(self, nucleus_mask, cytosol_mask): + self.load_masks(nucleus_mask, cytosol_mask) + self.initialize_lookup_table() + self.count_cytosol_occurances() + self.check_for_unassigned_cytosols() + self.identify_multinucleated_cells() + self.cleanup_filtering_lists() + self.cleanup_lookup_dictionary() + + self.nucleus_lookup_dict = self._nucleus_lookup_dict #save final result to a new variable name + + return(self.nucleus_lookup_dict) + + def filter(self, nucleus_mask, cytosol_mask): + + self.load_masks(nucleus_mask, cytosol_mask) + self.generate_lookup_table(nucleus_mask, cytosol_mask) + + return self.update_masks() \ No newline at end of file From 396e97c95751892691c87316558d731ee6344fb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Mon, 3 Jun 2024 16:45:00 +0200 Subject: [PATCH 02/18] update segmentation workflow to utilize prewritten filter --- src/sparcscore/pipeline/workflows.py | 287 +-------------------------- 1 file changed, 8 insertions(+), 279 deletions(-) diff --git a/src/sparcscore/pipeline/workflows.py b/src/sparcscore/pipeline/workflows.py index b15c7ca2..0159fb1a 100644 --- a/src/sparcscore/pipeline/workflows.py +++ b/src/sparcscore/pipeline/workflows.py @@ -5,7 +5,7 @@ MultithreadedSegmentation, ) from sparcscore.processing.preprocessing import percentile_normalization, downsample_img -from sparcscore.processing.filtering import SizeFilter +from sparcscore.processing.filtering import SizeFilter, MatchNucleusCytosolIds from sparcscore.processing.utils import visualize_class from sparcscore.processing.segmentation import ( segment_local_threshold, @@ -940,288 +940,20 @@ def cellpose_segmentation(self, input_image): ) masks_cytosol = filter_cytosol.filter(masks_cytosol) + ###################### + ### Perform Filtering match cytosol and nucleus IDs if applicable + ###################### + if not self.filter_status: self.log( "No filtering performed. Cytosol and Nucleus IDs in the two masks do not match. Before proceeding with extraction an additional filtering step needs to be performed" ) else: - ########################## - ### Perform Cell Filtering - ########################## - - # log start time of cell filtering to track - timing_info = [] - - start = time.time() - timing_info.append( - ("start_time", "Time when started the segmentation run", start) - ) - - all_nucleus_ids = np.unique(masks_nucleus)[1:] - nucleus_cytosol_pairs = {} - - self.log(f"Number of nuclei to filter: {len(all_nucleus_ids)}") - - ### STEP 1: filter cells based on having a matching cytosol mask - current_time = time.time() - timing_info.append( - ( - "start_time", - "Time when starting filtering cells (for nucleus_id in all_nucleus_ids) = STEP 1", - current_time, - ) - ) - - for nucleus_id in all_nucleus_ids: - ### STEP 1.1: lookup which image pixels belong to the nucleus - time_in_the_loop = time.time() - - # get the nucleus and set the background to 0 and the nucleus to 1 - nucleus = masks_nucleus == nucleus_id - - # now get the coordinates of the nucleus - nucleus_pixels = np.nonzero(nucleus) - - timing_info.append( - ( - "STEP 1.1", - f"Time required for getting nucleus pixels in seconds for nucleus {nucleus_id}", - time.time() - time_in_the_loop, - ) - ) - - ### Step 1.2: get the cytosol ids in the nucleus area - time_in_the_loop = time.time() - - # check if those indices are not background in the cytosol mask - potential_cytosol = masks_cytosol[nucleus_pixels] - - timing_info.append( - ( - "STEP 1.2", - f"Time required for getting potential cytosol pixels in seconds for nucleus {nucleus_id}", - time.time() - time_in_the_loop, - ) - ) - - if np.all(potential_cytosol != 0): - time_in_the_loop = time.time() - - unique_cytosol, counts = np.unique( - potential_cytosol, return_counts=True - ) - all_counts = np.sum(counts) - cytosol_proportions = counts / all_counts - - timing_info.append( - ( - "STEP 1.3", - f"Time required for getting unique cytosol pixels and calculating their proportions in seconds for nucleus {nucleus_id}", - time.time() - time_in_the_loop, - ) - ) - - if np.any( - cytosol_proportions >= self.config["filtering_threshold"] - ): - time_in_the_loop = time.time() - - # get the cytosol_id with max proportion - cytosol_id = unique_cytosol[ - np.argmax( - cytosol_proportions - >= self.config["filtering_threshold"] - ) - ] - nucleus_cytosol_pairs[nucleus_id] = cytosol_id - else: - nucleus_cytosol_pairs[nucleus_id] = 0 - - timing_info.append( - ( - "STEP 1.4", - f"Time required for getting cytosol_id with max proportion in seconds for nucleus {nucleus_id}", - time.time() - time_in_the_loop, - ) - ) - - timing_info.append( - ( - "STEP 1", - "Time required for filtering cells (for nucleus_id in all_nucleus_ids) in seconds", - time.time() - current_time, - ) - ) - - ####################################################### - ### STEP 2: count the occurrences of each cytosol value - ####################################################### - new_time = time.time() - timing_info.append( - ( - "start_time", - "Time when started counting the occurences of each cytosol id = STEP 2", - new_time, - ) - ) - - # check if there are any cytosol masks that are assigned to multiple nuclei - cytosol_count = defaultdict(int) - - # Count the occurrences of each cytosol value - for cytosol in nucleus_cytosol_pairs.values(): - cytosol_count[cytosol] += 1 - - timing_info.append( - ( - "STEP 2", - "Time required for counting the occurences of each cytosol id in seconds", - time.time() - new_time, - ) - ) - ####################################################### - ### STEP 3: filter cytosol ids that are assigned to more than one nucleus - ####################################################### - - new_time = time.time() - timing_info.append( - ( - "start_time", - "Time when started finding cytosol ids assigned to more than one nucleus = STEP 3", - new_time, - ) - ) - - # Find cytosol values assigned to more than one nucleus - for nucleus, cytosol in nucleus_cytosol_pairs.items(): - if cytosol_count[cytosol] > 1: - nucleus_cytosol_pairs[nucleus] = 0 - - timing_info.append( - ( - "STEP 3", - "Time required for filtering cytosol ids that are assigned to more than one nucleus in seconds", - time.time() - new_time, - ) - ) - - ####################################################### - ### STEP 4: filter cytosol masks that are not in the lookup table - ####################################################### - - new_time = time.time() - timing_info.append( - ( - "start_time", - "Time when started filtering cytosol masks that are not in the lookup table = STEP 4", - new_time, - ) - ) - - # get unique cytosol ids that are not in the lookup table - all_cytosol_ids = set(np.unique(masks_cytosol)) - all_cytosol_ids.discard(0) - used_cytosol_ids = set(nucleus_cytosol_pairs.values()) - not_used_cytosol_ids = all_cytosol_ids - used_cytosol_ids - - # set all cytosol ids that are not present in lookup table to 0 in the cytosol mask - ###speedup of 40X approximately in a small test case with an array of 10000x10000 and 400 cytosol ids to remove - # masks_cytosol = np.where(np.isin(masks_cytosol, not_used_cytosol_ids), 0, masks_cytosol) - for cytosol_id in not_used_cytosol_ids: - masks_cytosol[masks_cytosol == cytosol_id] = 0 - - timing_info.append( - ( - "STEP 4", - "Time required for filtering cytosol masks that are not in the lookup table in seconds", - time.time() - new_time, - ) - ) - - ### STEP 5: filter nucleus masks that are not in the lookup table - new_time = time.time() - timing_info.append( - ( - "start_time", - "Time when started filtering nucleus masks that are not in the lookup table = STEP 5", - new_time, - ) - ) - - # get unique nucleus ids that are not in the lookup table - all_nucleus_ids = set(np.unique(masks_nucleus)) - all_nucleus_ids.discard(0) - used_nucleus_ids = set(nucleus_cytosol_pairs.keys()) - not_used_nucleus_ids = all_nucleus_ids - used_nucleus_ids - - # set all nucleus ids that are not present in lookup table to 0 in the nucleus mask - ###speedup of 40X approximately in a small test case with an array of 10000x10000 and 400 cytosol ids to remove - # masks_nucleus = np.where(np.isin(masks_nucleus, not_used_nucleus_ids), 0, masks_nucleus) - for nucleus_id in not_used_nucleus_ids: - masks_nucleus[masks_nucleus == nucleus_id] = 0 - - timing_info.append( - ( - "STEP 5", - "Time required for filtering nucleus masks that are not in the lookup table in seconds", - time.time() - new_time, - ) - ) - - ################################################################# - ### STEP 6: filter cytosol masks that are not in the lookup table - ################################################################# - - new_time = time.time() - timing_info.append(("Time when started updating masks = STEP 6", new_time)) - - # now we have all the nucleus cytosol pairs we can filter the masks - updated_cytosol_mask = np.zeros_like(masks_cytosol, dtype=bool) - for nucleus_id, cytosol_id in nucleus_cytosol_pairs.items(): - if cytosol_id == 0: - masks_nucleus[masks_nucleus == nucleus_id] = ( - 0 # set the nucleus to 0 - ) - else: - # set the cytosol pixels to the nucleus_id if not previously updated - condition = np.logical_and( - masks_cytosol == cytosol_id, ~updated_cytosol_mask - ) - masks_cytosol[condition] = nucleus_id - updated_cytosol_mask = np.logical_or( - updated_cytosol_mask, condition - ) - - timing_info.append( - ( - "STEP 6", - "Time required for filtering cytosol masks that are not in the lookup table in seconds", - time.time() - new_time, - ) - ) - end = time.time() - - timing_info.append( - ( - "All STEPS", - "Time required for filtering generated masks in seconds", - end - start, - ) - ) - self.log( - f"Time required for filtering generated masks in seconds: {end - start}" - ) - - # generate a dataframe with the time logging information and write out to file - df_timing = pd.DataFrame( - timing_info, columns=["Step", "description", "Time (s)"] - ) - df_timing.to_csv( - f"{self.project_location}/segmentation/timing_info_{self.identifier}.csv", - index=False, - ) + # perform filtering to remove cytosols which do not have a corresponding nucleus + filter = MatchNucleusCytosolIds(filtering_threshold = self.config["filtering_threshold"]) + masks_nucleus, masks_cytosol = filter.filter(masks_nucleus, masks_cytosol) if self.debug: # plot nucleus and cytosol masks before and after filtering @@ -1241,9 +973,6 @@ def cellpose_segmentation(self, input_image): fig.show() del fig # delete figure after showing to free up memory again - # cleanup memory by deleting no longer required variables - del updated_cytosol_mask, all_nucleus_ids, used_nucleus_ids - # first when the masks are finalized save them to the maps self.maps["nucleus_segmentation"] = masks_nucleus.reshape( masks_nucleus.shape[1:] From 8c6aa0276db2b5d7a344a5da18198554699d859a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Mon, 3 Jun 2024 16:45:13 +0200 Subject: [PATCH 03/18] update filtering workflow to utilize seperate filtering function --- .../pipeline/filtering_workflows.py | 111 ++++-------------- 1 file changed, 20 insertions(+), 91 deletions(-) diff --git a/src/sparcscore/pipeline/filtering_workflows.py b/src/sparcscore/pipeline/filtering_workflows.py index bc455fba..6bb39cfc 100644 --- a/src/sparcscore/pipeline/filtering_workflows.py +++ b/src/sparcscore/pipeline/filtering_workflows.py @@ -4,11 +4,9 @@ ) import numpy as np -from tqdm.auto import tqdm -import shutil -from collections import defaultdict -from sparcscore.processing.preprocessing import downsample_img_pxs +from sparcscore.processing.filtering import MatchNucleusCytosolIds + class BaseFiltering(SegmentationFilter): def __init__(self, *args, **kwargs): @@ -26,105 +24,36 @@ class filtering_match_nucleus_to_cytosol(BaseFiltering): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def match_nucleus_id_to_cytosol( - self, nucleus_mask, cytosol_mask, return_ids_to_discard=False - ): - all_nucleus_ids = self.get_unique_ids(nucleus_mask) - all_cytosol_ids = self.get_unique_ids(cytosol_mask) - - nucleus_cytosol_pairs = {} - nuclei_ids_to_discard = [] - - for nucleus_id in tqdm(all_nucleus_ids): - # get the nucleus and set the background to 0 and the nucleus to 1 - nucleus = nucleus_mask == nucleus_id - - # now get the coordinates of the nucleus - nucleus_pixels = np.nonzero(nucleus) - - # check if those indices are not background in the cytosol mask - potential_cytosol = cytosol_mask[nucleus_pixels] - - # if there is a cytosolID in the area of the nucleus proceed, else continue with a new nucleus - if np.all(potential_cytosol != 0): - unique_cytosol, counts = np.unique( - potential_cytosol, return_counts=True - ) - all_counts = np.sum(counts) - cytosol_proportions = counts / all_counts - - if np.any(cytosol_proportions >= self.config["filtering_threshold"]): - # get the cytosol_id with max proportion - cytosol_id = unique_cytosol[ - np.argmax( - cytosol_proportions >= self.config["filtering_threshold"] - ) - ] - nucleus_cytosol_pairs[nucleus_id] = cytosol_id - else: - # no cytosol found with sufficient quality to call so discard nucleus - nuclei_ids_to_discard.append(nucleus_id) - - else: - # discard nucleus as no matching cytosol found - nuclei_ids_to_discard.append(nucleus_id) - - # check to ensure that only one nucleus_id is assigned to each cytosol_id - cytosol_count = defaultdict(int) - - # Count the occurrences of each cytosol value - for cytosol in nucleus_cytosol_pairs.values(): - cytosol_count[cytosol] += 1 - - # Find cytosol values assigned to more than one nucleus and remove from dictionary - multi_nucleated_nulceus_ids = [] + self.filter_threshold = self.config["filter_threshold"] - for nucleus, cytosol in nucleus_cytosol_pairs.items(): - if cytosol_count[cytosol] > 1: - multi_nucleated_nulceus_ids.append(nucleus) - - # update list of all nuclei used - nuclei_ids_to_discard.append(multi_nucleated_nulceus_ids) - - # remove entries from dictionary - # this needs to be put into a seperate loop because otherwise the dictionary size changes during loop and this throws an error - for nucleus in multi_nucleated_nulceus_ids: - del nucleus_cytosol_pairs[nucleus] - - # get all cytosol_ids that need to be discarded - used_cytosol_ids = set(nucleus_cytosol_pairs.values()) - not_used_cytosol_ids = set(all_cytosol_ids) - used_cytosol_ids - not_used_cytosol_ids = list(not_used_cytosol_ids) - - if return_ids_to_discard: - return (nucleus_cytosol_pairs, nuclei_ids_to_discard, not_used_cytosol_ids) + # allow for optional downsampling to improve computation time + if "downsampling_factor" in self.config.keys(): + self.N = self.config["downsampling_factor"] + self.kernel_size = self.config["downsampling_smoothing_kernel_size"] + self.erosion_dilation = self.config["downsampling_erosion_dilation"] else: - return nucleus_cytosol_pairs + self.N = None + self.kernel_size = None + self.erosion_dilation = False def process(self, input_masks): if isinstance(input_masks, str): input_masks = self.read_input_masks(input_masks) - # allow for optional downsampling to improve computation time - if "downsampling_factor" in self.config.keys(): - N = self.config["downsampling_factor"] - # use a less precise but faster downsampling method that preserves integer values - input_masks = downsample_img_pxs(input_masks, N=N) - - # get input masks - nucleus_mask = input_masks[0, :, :] - cytosol_mask = input_masks[1, :, :] - - nucleus_cytosol_pairs = self.match_nucleus_id_to_cytosol( - nucleus_mask, cytosol_mask + # perform filtering + filter = MatchNucleusCytosolIds( + filter_config=self.filter_threshold, + downsample_factor=self.N, + smoothing_kernel_size=self.kernel_size, + erosion_dilation=self.erosion_dilation, + ) + nucleus_cytosol_pairs = filter.generate_lookup_table( + input_masks[0], input_masks[1] ) # save results self.save_classes(classes=nucleus_cytosol_pairs) - # cleanup TEMP directories if not done during individual tile runs - if hasattr(self, "TEMP_DIR_NAME"): - shutil.rmtree(self.TEMP_DIR_NAME) class multithreaded_filtering_match_nucleus_to_cytosol(TiledSegmentationFilter): method = filtering_match_nucleus_to_cytosol From 7ad409344700458ce66b155a55127207bd87afeb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Mon, 3 Jun 2024 16:47:14 +0200 Subject: [PATCH 04/18] fix inconsistent variable naming --- src/sparcscore/processing/filtering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparcscore/processing/filtering.py b/src/sparcscore/processing/filtering.py index 9add1cee..02caed3a 100644 --- a/src/sparcscore/processing/filtering.py +++ b/src/sparcscore/processing/filtering.py @@ -546,7 +546,7 @@ def count_cytosol_occurances(self): for cytosol in self._nucleus_lookup_dict.values(): cytosol_count[cytosol] += 1 - self.cytosol_counts = cytosol_count + self.cytosol_count = cytosol_count def check_for_unassigned_cytosols(self): all_cytosol_ids = self.get_unique_ids(self.cytosol_mask) From 3da72f355fcc9c4055fdcdc21dba6654f4f1cc45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Mon, 3 Jun 2024 16:49:27 +0200 Subject: [PATCH 05/18] ensure cytosol mask is updated with the correct function --- src/sparcscore/processing/filtering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparcscore/processing/filtering.py b/src/sparcscore/processing/filtering.py index 02caed3a..ff03b64f 100644 --- a/src/sparcscore/processing/filtering.py +++ b/src/sparcscore/processing/filtering.py @@ -477,7 +477,7 @@ def update_cytosol_mask(self, cytosol_mask): def update_masks(self): nucleus_mask = self.update_mask(self.nucleus_mask, self.nuclei_discard_list) - cytosol_mask = self.update_mask(self.cytosol_mask) + cytosol_mask = self.update_cytosol_mask(self.cytosol_mask) if self.downsample: nucleus_mask = self.upscale_mask_basic(nucleus_mask, self.erosion_dilation) From f23b9a07bd91f89ffce2478dc464b8047f45399f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Mon, 3 Jun 2024 16:57:57 +0200 Subject: [PATCH 06/18] add docstrings to filtering method --- src/sparcscore/processing/filtering.py | 208 +++++++++++++++++++++---- 1 file changed, 177 insertions(+), 31 deletions(-) diff --git a/src/sparcscore/processing/filtering.py b/src/sparcscore/processing/filtering.py index ff03b64f..b8d06059 100644 --- a/src/sparcscore/processing/filtering.py +++ b/src/sparcscore/processing/filtering.py @@ -418,6 +418,80 @@ def filter(self, input_mask): return self.update_mask(input_mask, self.ids_to_remove) class MatchNucleusCytosolIds(BaseFilter): + """ + Filter class for matching nucleusIDs to their matching cytosol IDs and removing all classes from the given + segmentation masks that do not fullfill the filtering criteria. + + Masks only pass filtering if both a nucleus and a cytosol mask are present and have an overlapping area + larger than the specified threshold. If the threshold is not specified, the default value is set to 0.5. + + Parameters + ---------- + filtering_threshold : float, optional + The threshold for filtering cytosol IDs based on the proportion of overlapping area with the nucleus. Default is 0.5. + downsampling_factor : int, optional + The downsampling factor for the masks. Default is None. + erosion_dilation : bool, optional + Flag indicating whether to perform erosion and dilation on the masks during upscaling. Default is True. + smoothing_kernel_size : int, optional + The size of the smoothing kernel for upscaling. Default is 7. + *args + Additional positional arguments. + **kwargs + Additional keyword arguments. + + Attributes + ---------- + filtering_threshold : float + The threshold for filtering cytosol IDs based on the proportion of overlapping area with the nucleus. + downsample : bool + Flag indicating whether downsampling is enabled. + downsampling_factor : int + The downsampling factor for the masks. + erosion_dilation : bool + Flag indicating whether to perform erosion and dilation on the masks during upscaling. + smoothing_kernel_size : int + The size of the smoothing kernel for upscaling. + nucleus_mask : numpy.ndarray + The nucleus mask. + cytosol_mask : numpy.ndarray + The cytosol mask. + nuclei_discard_list : list + A list of nucleus IDs to be discarded. + cytosol_discard_list : list + A list of cytosol IDs to be discarded. + nucleus_lookup_dict : dict + A dictionary mapping nucleus IDs to matched cytosol IDs after filtering. + + Methods + ------- + load_masks(nucleus_mask, cytosol_mask) + Load the nucleus and cytosol masks. + update_cytosol_mask(cytosol_mask) + Update the cytosol mask based on the matched nucleus-cytosol pairs. + update_masks() + Update the nucleus and cytosol masks after filtering. + match_nucleus_id(nucleus_id) + Match the given nucleus ID to a cytosol ID based on the overlapping area. + initialize_lookup_table() + Initialize the lookup table by matching all nucleus IDs to cytosol IDs. + count_cytosol_occurances() + Count the occurrences of each cytosol ID in the lookup table. + check_for_unassigned_cytosols() + Check for unassigned cytosol IDs in the cytosol mask. + identify_multinucleated_cells() + Identify and discard multinucleated cells from the lookup table. + cleanup_filtering_lists() + Cleanup the discard lists by removing duplicate entries. + cleanup_lookup_dictionary() + Cleanup the lookup dictionary by removing discarded nucleus-cytosol pairs. + generate_lookup_table(nucleus_mask, cytosol_mask) + Generate the lookup table by performing all necessary steps. + filter(nucleus_mask, cytosol_mask) + Filter the nucleus and cytosol masks based on the matching results. + + """ + def __init__( self, filtering_threshold=0.5, @@ -428,31 +502,46 @@ def __init__( **kwargs, ): super().__init__(*args, **kwargs) + + #set relevant parameters self.filtering_threshold = filtering_threshold - #set up downsampling + #set downsampling parameters if downsampling_factor is not None: - self.downsample = True + self.downsample = True self.downsampling_factor = downsampling_factor self.erosion_dilation = erosion_dilation self.smoothing_kernel_size = smoothing_kernel_size else: self.downsample = False - + + #initialize placeholders for masks self.nucleus_mask = None self.cytosol_mask = None + # initialize datastructures for saving results self._nucleus_lookup_dict = {} self.nuclei_discard_list = [] self.cytosol_discard_list = [] + self.nucleus_lookup_dict = None def load_masks(self, nucleus_mask, cytosol_mask): - #masks are only loaded once (if already loaded nothing happens) + """ + Load the nucleus and cytosol masks into their placeholders. + This function only loads the masks (and downsamples them if necessary) if this has not already been performed. + + Parameters + ---------- + nucleus_mask : numpy.ndarray + The nucleus mask. + cytosol_mask : numpy.ndarray + The cytosol mask. + """ if self.downsample: if self.nucleus_mask is None: self.nucleus_mask = self.downsample_mask(nucleus_mask) if self.cytosol_mask is None: - self.cytosol_mask = self.downsample_mask(cytosol_mask) + self.cytosol_mask = self.downsample_mask(cytosol_mask) else: if self.nucleus_mask is None: self.nucleus_mask = nucleus_mask @@ -460,22 +549,40 @@ def load_masks(self, nucleus_mask, cytosol_mask): self.cytosol_mask = cytosol_mask def update_cytosol_mask(self, cytosol_mask): - # now we have all the nucleus cytosol pairs we can filter the masks + """ + Update the cytosol mask based on the matched nucleus-cytosol pairs. + + Parameters + ---------- + cytosol_mask : numpy.ndarray + The cytosol mask. + + Returns + ------- + numpy.ndarray + The updated cytosol mask. + """ updated = np.zeros_like(cytosol_mask, dtype=bool) for nucleus_id, cytosol_id in self.nucleus_lookup_dict.items(): - # set the cytosol pixels to the nucleus_id if not previously updated - condition = np.logical_and( - cytosol_mask == cytosol_id, ~updated - ) - cytosol_mask[condition] = nucleus_id - updated = np.logical_or( - updated, condition - ) - return(cytosol_mask) + condition = np.logical_and( + cytosol_mask == cytosol_id, ~updated + ) + cytosol_mask[condition] = nucleus_id + updated = np.logical_or( + updated, condition + ) + return cytosol_mask def update_masks(self): + """ + Update the nucleus and cytosol masks after filtering. + Returns + ------- + tuple + A tuple containing the updated nucleus mask and cytosol mask. + """ nucleus_mask = self.update_mask(self.nucleus_mask, self.nuclei_discard_list) cytosol_mask = self.update_cytosol_mask(self.cytosol_mask) @@ -491,32 +598,23 @@ def match_nucleus_id(self, nucleus_id): Parameters ---------- - nucleus_mask : numpy.ndarray - The nucleus mask. - cytosol_mask : numpy.ndarray - The cytosol mask. nucleus_id : int The nucleus ID to be matched. Returns ------- - int - The matched cytosol ID. + int or None + The matched cytosol ID, or None if no match is found. """ - # get the coordinates of the nucleus nucleus_pixels = np.where(self.nucleus_mask == nucleus_id) - - # check if those indices are not background in the cytosol mask potential_cytosol = self.cytosol_mask[nucleus_pixels] - # if there is a cytosolID in the area of the nucleus proceed, else continue with a new nucleus if np.all(potential_cytosol != 0): unique_cytosol, counts = np.unique(potential_cytosol, return_counts=True) all_counts = np.sum(counts) cytosol_proportions = counts / all_counts if np.any(cytosol_proportions >= self.filtering_threshold): - # get the cytosol_id with max proportion cytosol_id = unique_cytosol[ np.argmax(cytosol_proportions >= self.filtering_threshold) ] @@ -534,21 +632,29 @@ def match_nucleus_id(self, nucleus_id): return None def initialize_lookup_table(self): + """ + Initialize the lookup table by matching all nucleus IDs to cytosol IDs. + """ all_nucleus_ids = self.get_unique_ids(self.nucleus_mask) for nucleus_id in all_nucleus_ids: self.match_nucleus_id(nucleus_id) def count_cytosol_occurances(self): + """ + Count the occurrences of each cytosol ID in the lookup table. + """ cytosol_count = defaultdict(int) - # Count the occurrences of each cytosol value for cytosol in self._nucleus_lookup_dict.values(): cytosol_count[cytosol] += 1 self.cytosol_count = cytosol_count def check_for_unassigned_cytosols(self): + """ + Check for unassigned cytosol IDs in the cytosol mask. + """ all_cytosol_ids = self.get_unique_ids(self.cytosol_mask) for cytosol_id in all_cytosol_ids: @@ -556,16 +662,25 @@ def check_for_unassigned_cytosols(self): self.cytosol_discard_list.append(cytosol_id) def identify_multinucleated_cells(self): + """ + Identify and discard multinucleated cells from the lookup table. + """ for nucleus, cytosol in self._nucleus_lookup_dict.items(): if self.cytosol_count[cytosol] > 1: self.nuclei_discard_list.append(nucleus) self.cytosol_discard_list.append(cytosol) def cleanup_filtering_lists(self): + """ + Cleanup the discard lists by removing duplicate entries. + """ self.nuclei_discard_list = list(set(self.nuclei_discard_list)) self.cytosol_discard_list = list(set(self.cytosol_discard_list)) def cleanup_lookup_dictionary(self): + """ + Cleanup the lookup dictionary by removing discarded nucleus-cytosol pairs. + """ _cleanup = [] for nucleus_id, cytosol_id in self._nucleus_lookup_dict.items(): if nucleus_id in self.nuclei_discard_list: @@ -577,6 +692,21 @@ def cleanup_lookup_dictionary(self): del self._nucleus_lookup_dict[nucleus] def generate_lookup_table(self, nucleus_mask, cytosol_mask): + """ + Generate the lookup table by performing all necessary steps. + + Parameters + ---------- + nucleus_mask : numpy.ndarray + The nucleus mask. + cytosol_mask : numpy.ndarray + The cytosol mask. + + Returns + ------- + dict + The lookup table mapping nucleus IDs to matched cytosol IDs. + """ self.load_masks(nucleus_mask, cytosol_mask) self.initialize_lookup_table() self.count_cytosol_occurances() @@ -585,13 +715,29 @@ def generate_lookup_table(self, nucleus_mask, cytosol_mask): self.cleanup_filtering_lists() self.cleanup_lookup_dictionary() - self.nucleus_lookup_dict = self._nucleus_lookup_dict #save final result to a new variable name + self.nucleus_lookup_dict = self._nucleus_lookup_dict - return(self.nucleus_lookup_dict) + return self.nucleus_lookup_dict def filter(self, nucleus_mask, cytosol_mask): - + """ + Filter the nucleus and cytosol masks based on the matching results and return the updated masks. + + Parameters + ---------- + nucleus_mask : numpy.ndarray + The nucleus mask. + cytosol_mask : numpy.ndarray + The cytosol mask. + + Returns + ------- + tuple + A tuple containing the updated nucleus mask and cytosol mask after filtering. + """ self.load_masks(nucleus_mask, cytosol_mask) - self.generate_lookup_table(nucleus_mask, cytosol_mask) + + if self.nucleus_lookup_dict is None: + self.generate_lookup_table(nucleus_mask, cytosol_mask) return self.update_masks() \ No newline at end of file From 4edefad4b6b9cedf53fa3cc74823a717ebd2b1ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Mon, 3 Jun 2024 18:15:57 +0200 Subject: [PATCH 07/18] remove potential duplicates from cleanup function before processing dictionary --- src/sparcscore/processing/filtering.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/sparcscore/processing/filtering.py b/src/sparcscore/processing/filtering.py index b8d06059..c6e31a6b 100644 --- a/src/sparcscore/processing/filtering.py +++ b/src/sparcscore/processing/filtering.py @@ -688,6 +688,8 @@ def cleanup_lookup_dictionary(self): if cytosol_id in self.cytosol_discard_list: _cleanup.append(nucleus_id) + #ensure we have no duplicate entries + _cleanup = list(set(_cleanup)) for nucleus in _cleanup: del self._nucleus_lookup_dict[nucleus] From 218d9363d761a109c97eb69a915d15b2b7a22977 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Mon, 3 Jun 2024 18:42:14 +0200 Subject: [PATCH 08/18] add log messages tracking how many cells were removed through filtering --- src/sparcscore/pipeline/workflows.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/sparcscore/pipeline/workflows.py b/src/sparcscore/pipeline/workflows.py index 0159fb1a..106d0d3c 100644 --- a/src/sparcscore/pipeline/workflows.py +++ b/src/sparcscore/pipeline/workflows.py @@ -951,10 +951,17 @@ def cellpose_segmentation(self, input_image): else: + self.log( + " Performing filtering to match Cytosol and Nucleus IDs." + ) + # perform filtering to remove cytosols which do not have a corresponding nucleus filter = MatchNucleusCytosolIds(filtering_threshold = self.config["filtering_threshold"]) masks_nucleus, masks_cytosol = filter.filter(masks_nucleus, masks_cytosol) + self.log(f"Removed {len(filter.nuclei_discard_list)} nuclei and {len(filter.cytosol_discard_list)} cytosols due to filtering.") + self.log(f"After filtering, {len(filter.nucleus_lookup_dict)} matching nuclei and cytosol masks remain.") + if self.debug: # plot nucleus and cytosol masks before and after filtering fig, axs = plt.subplots(2, 2, figsize=(8, 8)) From dfa5d40bc51086b37c61ea0c67293b1fa6d23ceb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Tue, 4 Jun 2024 12:28:17 +0200 Subject: [PATCH 09/18] remove unused variable --- src/sparcscore/pipeline/workflows.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/sparcscore/pipeline/workflows.py b/src/sparcscore/pipeline/workflows.py index ab37745f..b61b28d0 100644 --- a/src/sparcscore/pipeline/workflows.py +++ b/src/sparcscore/pipeline/workflows.py @@ -1378,10 +1378,6 @@ def _finalize_segmentation_results(self, size_padding): ) _, x, y = size_padding - segmentation_size = ( - x, - y, - ) # return to same size as original input image but adjust number of channels expected cyto_seg = self.maps["cytosol_segmentation"] cyto_seg = cyto_seg.repeat(self.config["downsampling_factor"], axis=0).repeat( From ecf30b90b47b1ab9842ffec0900d0c8d121edc70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Tue, 4 Jun 2024 13:59:21 +0200 Subject: [PATCH 10/18] add function to get downsampling parameters --- src/sparcscore/pipeline/workflows.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/src/sparcscore/pipeline/workflows.py b/src/sparcscore/pipeline/workflows.py index b61b28d0..665c2130 100644 --- a/src/sparcscore/pipeline/workflows.py +++ b/src/sparcscore/pipeline/workflows.py @@ -1057,6 +1057,7 @@ def _finalize_segmentation_results(self, size_padding): channels = np.stack(required_maps).astype(np.uint16) _seg_size = self.maps["nucleus_segmentation"].shape + self.log( f"Segmentation size after downsampling before resize to original dimensions: {_seg_size}" ) @@ -1064,28 +1065,24 @@ def _finalize_segmentation_results(self, size_padding): # rescale downsampled segmentation results to original size by repeating pixels _, x, y = size_padding + N, smoothing_kernel_size = _get_downsampling_parameters() + nuc_seg = self.maps["nucleus_segmentation"] - nuc_seg = nuc_seg.repeat(self.config["downsampling_factor"], axis=0).repeat( - self.config["downsampling_factor"], axis=1 - ) + nuc_seg = nuc_seg.repeat(N, axis=0).repeat(N, axis=1) cyto_seg = self.maps["cytosol_segmentation"] - cyto_seg = cyto_seg.repeat(self.config["downsampling_factor"], axis=0).repeat( - self.config["downsampling_factor"], axis=1 - ) + cyto_seg = cyto_seg.repeat(N, axis=0).repeat(N, axis=1) # perform erosion and dilation for smoothing - nuc_seg = erosion(nuc_seg, footprint=disk(self.config["smoothing_kernel_size"])) + nuc_seg = erosion(nuc_seg, footprint=disk(smoothing_kernel_size)) nuc_seg = dilation( - nuc_seg, footprint=disk(self.config["smoothing_kernel_size"]) - ) + nuc_seg, footprint=disk(smoothing_kernel_size + 1) + ) # dilate 1 more than eroded to ensure that we do not lose any pixels - cyto_seg = erosion( - cyto_seg, footprint=disk(self.config["smoothing_kernel_size"]) - ) + cyto_seg = erosion(cyto_seg, footprint=disk(smoothing_kernel_size)) cyto_seg = dilation( - cyto_seg, footprint=disk(self.config["smoothing_kernel_size"]) - ) + cyto_seg, footprint=disk(smoothing_kernel_size + 1) + ) # dilate 1 more than eroded to ensure that we do not lose any pixels # combine masks into one stack segmentation = np.stack([nuc_seg, cyto_seg]).astype(np.uint32) From 643960d2a059445ea1ac168f78f5f8245233960b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Tue, 4 Jun 2024 14:00:28 +0200 Subject: [PATCH 11/18] implement sanity check to catch cases #4 --- src/sparcscore/pipeline/workflows.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/sparcscore/pipeline/workflows.py b/src/sparcscore/pipeline/workflows.py index 665c2130..ee0dacd6 100644 --- a/src/sparcscore/pipeline/workflows.py +++ b/src/sparcscore/pipeline/workflows.py @@ -1068,9 +1068,13 @@ def _finalize_segmentation_results(self, size_padding): N, smoothing_kernel_size = _get_downsampling_parameters() nuc_seg = self.maps["nucleus_segmentation"] + n_nuclei = len( + np.unique(nuc_seg)[0] + ) # get number of objects in mask for sanity checking nuc_seg = nuc_seg.repeat(N, axis=0).repeat(N, axis=1) cyto_seg = self.maps["cytosol_segmentation"] + n_cytosols = len(np.unique(cyto_seg)[0]) cyto_seg = cyto_seg.repeat(N, axis=0).repeat(N, axis=1) # perform erosion and dilation for smoothing @@ -1084,6 +1088,23 @@ def _finalize_segmentation_results(self, size_padding): cyto_seg, footprint=disk(smoothing_kernel_size + 1) ) # dilate 1 more than eroded to ensure that we do not lose any pixels + # sanity check to make sure that smoothing does not remove masks + if len(np.unique(nuc_seg)[0]) != n_nuclei: + self.log( + "Error. Number of nuclei in segmentation mask changed after smoothing. This should not happen. Ensure that you have chosen adequate smoothing parameters or use the defaults." + ) + sys.exit( + "Error. Number of nuclei in segmentation mask changed after smoothing. This should not happen. Ensure that you have chosen adequate smoothing parameters or use the defaults." + ) + + if len(np.unique(cyto_seg)[0]) != n_cytosols: + self.log( + "Error. Number of cytosols in segmentation mask changed after smoothing. This should not happen. Ensure that you have chosen adequate smoothing parameters or use the defaults." + ) + sys.exit( + "Error. Number of cytosols in segmentation mask changed after smoothing. This should not happen. Ensure that you have chosen adequate smoothing parameters or use the defaults." + ) + # combine masks into one stack segmentation = np.stack([nuc_seg, cyto_seg]).astype(np.uint32) del cyto_seg, nuc_seg From d0c0ef699bef1bb7793a650814b7921456b9d8eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Tue, 4 Jun 2024 14:18:10 +0200 Subject: [PATCH 12/18] add missing part of previous commit --- src/sparcscore/pipeline/workflows.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/sparcscore/pipeline/workflows.py b/src/sparcscore/pipeline/workflows.py index ee0dacd6..a5d9aeb5 100644 --- a/src/sparcscore/pipeline/workflows.py +++ b/src/sparcscore/pipeline/workflows.py @@ -1046,6 +1046,22 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def _finalize_segmentation_results(self, size_padding): + def _get_downsampling_parameters(self): + N = self.config["downsampling_factor"] + if "smoothing_kernel_size" in self.config.keys(): + smoothing_kernel_size = self.config["smoothing_kernel_size"] + + if smoothing_kernel_size > N: + self.log( + "Warning: Smoothing Kernel size is larger than the downsampling factor. This can lead to issues during smoothing where segmentation masks are lost. Please ensure to double check your results." + ) + + else: + self.log( + "Smoothing Kernel size not explicitly defined. Will calculate a default value based on the downsampling factor." + ) + smoothing_kernel_size = N + # nuclear and cyotosolic channels are required (used for segmentation) required_maps = [self.maps["normalized"][0], self.maps["normalized"][1]] From 623ef03e02c809975c88976976d5f16d22c8da64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Tue, 4 Jun 2024 14:27:29 +0200 Subject: [PATCH 13/18] relocate function to calculate downsampling parameters --- src/sparcscore/pipeline/workflows.py | 60 ++++++++++++++++------------ 1 file changed, 35 insertions(+), 25 deletions(-) mode change 100644 => 100755 src/sparcscore/pipeline/workflows.py diff --git a/src/sparcscore/pipeline/workflows.py b/src/sparcscore/pipeline/workflows.py old mode 100644 new mode 100755 index a5d9aeb5..b3f8b9ea --- a/src/sparcscore/pipeline/workflows.py +++ b/src/sparcscore/pipeline/workflows.py @@ -913,8 +913,10 @@ def cellpose_segmentation(self, input_image): masks_nucleus = filter_nucleus.filter(masks_nucleus) - self.log(f"Removed {len(filter_nucleus.ids_to_remove)} nuclei as they fell outside of the threshold range {filter_nucleus.threshold}.") - + self.log( + f"Removed {len(filter_nucleus.ids_to_remove)} nuclei as they fell outside of the threshold range {filter_nucleus.threshold}." + ) + # perform filtering for cytosol size thresholds, confidence_interval = self.get_params_cellsize_filtering( "cytosol" @@ -939,8 +941,10 @@ def cellpose_segmentation(self, input_image): ) masks_cytosol = filter_cytosol.filter(masks_cytosol) - self.log(f"Removed {len(filter_cytosol.ids_to_remove)} cytosols as they fell outside of the threshold range {filter_cytosol.threshold}.") - + self.log( + f"Removed {len(filter_cytosol.ids_to_remove)} cytosols as they fell outside of the threshold range {filter_cytosol.threshold}." + ) + ###################### ### Perform Filtering match cytosol and nucleus IDs if applicable ###################### @@ -951,17 +955,20 @@ def cellpose_segmentation(self, input_image): ) else: - - self.log( - "Performing filtering to match Cytosol and Nucleus IDs." - ) + self.log("Performing filtering to match Cytosol and Nucleus IDs.") # perform filtering to remove cytosols which do not have a corresponding nucleus - filter = MatchNucleusCytosolIds(filtering_threshold = self.config["filtering_threshold"]) + filter = MatchNucleusCytosolIds( + filtering_threshold=self.config["filtering_threshold"] + ) masks_nucleus, masks_cytosol = filter.filter(masks_nucleus, masks_cytosol) - self.log(f"Removed {len(filter.nuclei_discard_list)} nuclei and {len(filter.cytosol_discard_list)} cytosols due to filtering.") - self.log(f"After filtering, {len(filter.nucleus_lookup_dict)} matching nuclei and cytosol masks remain.") + self.log( + f"Removed {len(filter.nuclei_discard_list)} nuclei and {len(filter.cytosol_discard_list)} cytosols due to filtering." + ) + self.log( + f"After filtering, {len(filter.nucleus_lookup_dict)} matching nuclei and cytosol masks remain." + ) if self.debug: # plot nucleus and cytosol masks before and after filtering @@ -1044,24 +1051,26 @@ class ShardedCytosolSegmentationCellpose(ShardedSegmentation): class CytosolSegmentationDownsamplingCellpose(CytosolSegmentationCellpose): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + + def _get_downsampling_parameters(self): + N = self.config["downsampling_factor"] + if "smoothing_kernel_size" in self.config.keys(): + smoothing_kernel_size = self.config["smoothing_kernel_size"] - def _finalize_segmentation_results(self, size_padding): - def _get_downsampling_parameters(self): - N = self.config["downsampling_factor"] - if "smoothing_kernel_size" in self.config.keys(): - smoothing_kernel_size = self.config["smoothing_kernel_size"] - - if smoothing_kernel_size > N: - self.log( - "Warning: Smoothing Kernel size is larger than the downsampling factor. This can lead to issues during smoothing where segmentation masks are lost. Please ensure to double check your results." - ) - - else: + if smoothing_kernel_size > N: self.log( - "Smoothing Kernel size not explicitly defined. Will calculate a default value based on the downsampling factor." + "Warning: Smoothing Kernel size is larger than the downsampling factor. This can lead to issues during smoothing where segmentation masks are lost. Please ensure to double check your results." ) - smoothing_kernel_size = N + else: + self.log( + "Smoothing Kernel size not explicitly defined. Will calculate a default value based on the downsampling factor." + ) + smoothing_kernel_size = N + + return N, smoothing_kernel_size + + def _finalize_segmentation_results(self, size_padding): # nuclear and cyotosolic channels are required (used for segmentation) required_maps = [self.maps["normalized"][0], self.maps["normalized"][1]] @@ -1407,6 +1416,7 @@ def _finalize_segmentation_results(self, size_padding): channels = np.stack(required_maps).astype(np.uint16) _seg_size = self.maps["cytosol_segmentation"].shape + self.log( f"Segmentation size after downsampling before resize to original dimensions: {_seg_size}" ) From 0c09096f3ae05a60ada76d6002b14256c88acbd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Tue, 4 Jun 2024 14:29:38 +0200 Subject: [PATCH 14/18] fix function call --- src/sparcscore/pipeline/workflows.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparcscore/pipeline/workflows.py b/src/sparcscore/pipeline/workflows.py index b3f8b9ea..4e0dd361 100755 --- a/src/sparcscore/pipeline/workflows.py +++ b/src/sparcscore/pipeline/workflows.py @@ -1090,7 +1090,7 @@ def _finalize_segmentation_results(self, size_padding): # rescale downsampled segmentation results to original size by repeating pixels _, x, y = size_padding - N, smoothing_kernel_size = _get_downsampling_parameters() + N, smoothing_kernel_size = self._get_downsampling_parameters() nuc_seg = self.maps["nucleus_segmentation"] n_nuclei = len( From d2939091c91335c798a9725a1a27b0fecd75cc4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Tue, 4 Jun 2024 14:31:58 +0200 Subject: [PATCH 15/18] fix small bug in incorrect calculationg of number of classes --- src/sparcscore/pipeline/workflows.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/sparcscore/pipeline/workflows.py b/src/sparcscore/pipeline/workflows.py index 4e0dd361..2cdfe4bc 100755 --- a/src/sparcscore/pipeline/workflows.py +++ b/src/sparcscore/pipeline/workflows.py @@ -1093,13 +1093,11 @@ def _finalize_segmentation_results(self, size_padding): N, smoothing_kernel_size = self._get_downsampling_parameters() nuc_seg = self.maps["nucleus_segmentation"] - n_nuclei = len( - np.unique(nuc_seg)[0] - ) # get number of objects in mask for sanity checking + n_nuclei = len(np.unique(nuc_seg)) # get number of objects in mask for sanity checking nuc_seg = nuc_seg.repeat(N, axis=0).repeat(N, axis=1) cyto_seg = self.maps["cytosol_segmentation"] - n_cytosols = len(np.unique(cyto_seg)[0]) + n_cytosols = len(np.unique(cyto_seg)) cyto_seg = cyto_seg.repeat(N, axis=0).repeat(N, axis=1) # perform erosion and dilation for smoothing @@ -1114,7 +1112,7 @@ def _finalize_segmentation_results(self, size_padding): ) # dilate 1 more than eroded to ensure that we do not lose any pixels # sanity check to make sure that smoothing does not remove masks - if len(np.unique(nuc_seg)[0]) != n_nuclei: + if len(np.unique(nuc_seg)) != n_nuclei: self.log( "Error. Number of nuclei in segmentation mask changed after smoothing. This should not happen. Ensure that you have chosen adequate smoothing parameters or use the defaults." ) @@ -1122,7 +1120,7 @@ def _finalize_segmentation_results(self, size_padding): "Error. Number of nuclei in segmentation mask changed after smoothing. This should not happen. Ensure that you have chosen adequate smoothing parameters or use the defaults." ) - if len(np.unique(cyto_seg)[0]) != n_cytosols: + if len(np.unique(cyto_seg)) != n_cytosols: self.log( "Error. Number of cytosols in segmentation mask changed after smoothing. This should not happen. Ensure that you have chosen adequate smoothing parameters or use the defaults." ) From b00b7fc7e8bbe697c7649fa6521b1b3eae692fb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Wed, 5 Jun 2024 09:25:55 +0200 Subject: [PATCH 16/18] fix typo --- src/sparcscore/pipeline/segmentation.py | 2 +- src/sparcscore/pipeline/workflows.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sparcscore/pipeline/segmentation.py b/src/sparcscore/pipeline/segmentation.py index dd8a6e8c..a9bf1b55 100644 --- a/src/sparcscore/pipeline/segmentation.py +++ b/src/sparcscore/pipeline/segmentation.py @@ -280,7 +280,7 @@ def save_segmentation_zarr(self, labels=None): loc = parse_url(path, mode="w").store group = zarr.group(store=loc) - segmentation_names = ["nucleus", "cyotosol"] + segmentation_names = ["nucleus", "cytosol"] # check if segmentation names already exist if so delete for seg_names in segmentation_names: diff --git a/src/sparcscore/pipeline/workflows.py b/src/sparcscore/pipeline/workflows.py index 2cdfe4bc..e9c4a74d 100755 --- a/src/sparcscore/pipeline/workflows.py +++ b/src/sparcscore/pipeline/workflows.py @@ -1071,7 +1071,7 @@ def _get_downsampling_parameters(self): return N, smoothing_kernel_size def _finalize_segmentation_results(self, size_padding): - # nuclear and cyotosolic channels are required (used for segmentation) + # nuclear and cytosolic channels are required (used for segmentation) required_maps = [self.maps["normalized"][0], self.maps["normalized"][1]] # Feature maps are all further channel which contain additional phenotypes e.g. for classification From 28346c41542b469850c62e601f0253108913d1ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Wed, 5 Jun 2024 12:21:34 +0200 Subject: [PATCH 17/18] add additional log message --- src/sparcscore/pipeline/segmentation.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/sparcscore/pipeline/segmentation.py b/src/sparcscore/pipeline/segmentation.py index a9bf1b55..85a46ce2 100644 --- a/src/sparcscore/pipeline/segmentation.py +++ b/src/sparcscore/pipeline/segmentation.py @@ -657,6 +657,7 @@ def calculate_sharding_plan(self, image_size): def cleanup_shards(self, sharding_plan): file_identifiers_plots = [".png", ".tif", ".tiff", ".jpg", ".jpeg", ".pdf"] + self.log("Moving generated plots from shard directory to main directory.") for i, window in enumerate(sharding_plan): local_shard_directory = os.path.join(self.shard_directory, str(i)) for file in os.listdir(local_shard_directory): @@ -945,7 +946,6 @@ def process(self, input_image): def complete_segmentation(self, input_image): self.save_zarr = False - self.save_input_image(input_image) self.shard_directory = os.path.join(self.directory, self.DEFAULT_SHARD_FOLDER) # check to make sure that the shard directory exisits, if not exit and return error @@ -954,6 +954,9 @@ def complete_segmentation(self, input_image): "No Shard Directory found for the given project. Can not complete a segmentation which has not started. Please rerun the segmentation method." ) + #save input image to segmentation.h5 + self.save_input_image(input_image) + # check to see which tiles are incomplete tile_directories = os.listdir(self.shard_directory) incomplete_indexes = [] From 0b62d8ae6f129c35bbc69c31a8aad0c01e9bdc72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Wed, 5 Jun 2024 12:42:30 +0200 Subject: [PATCH 18/18] rename functions more consistently --- src/sparcscore/processing/filtering.py | 164 +++++++++++++++---------- 1 file changed, 101 insertions(+), 63 deletions(-) diff --git a/src/sparcscore/processing/filtering.py b/src/sparcscore/processing/filtering.py index c6e31a6b..aa6603ff 100644 --- a/src/sparcscore/processing/filtering.py +++ b/src/sparcscore/processing/filtering.py @@ -14,11 +14,11 @@ class BaseFilter(Logable): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - + def get_unique_ids(self, mask): return np.unique(mask)[1:] # to remove the background - def update_mask(self, mask, ids_to_remove): + def get_updated_mask(self, mask, ids_to_remove): """ Update the given mask by setting the values corresponding to the specified IDs to 0. @@ -36,10 +36,10 @@ def update_mask(self, mask, ids_to_remove): """ return np.where(np.isin(mask, ids_to_remove), 0, mask) - def downsample_mask(self, mask): + def get_downsampled_mask(self, mask): return downsample_img_pxs(mask, N=self.downsampling_factor) - def upscale_mask_basic(self, mask, erosion_dilation=False): + def get_upscaled_mask_basic(self, mask, erosion_dilation=False): mask = mask.repeat(self.downsampling_factor, axis=0).repeat( self.downsampling_factor, axis=1 ) @@ -50,6 +50,7 @@ def upscale_mask_basic(self, mask, erosion_dilation=False): return mask + class SizeFilter(BaseFilter): """ Filter class for removing objects from a mask based on their size. @@ -80,6 +81,15 @@ class SizeFilter(BaseFilter): The directory to save the generated figures. If not provided, the current working directory will be used. confidence_interval : float, optional The confidence interval for calculating the filtering threshold. Default is 0.95. + n_components : int, optional + The number of components in the Gaussian mixture model. Default is 1. + population_to_keep : str, optional + For multipopulation models this parameter determines which population should be kept. Options are "largest", "smallest", "mostcommon", "leastcommon". Default is "mostcommon". + If set to "largest" or "smallest", the model is chosen which has the largest or smallest mean. If set to "mostcommon" or "leastcommon", the model is chosen whose population is least or most common. + *args + Additional positional arguments. + **kwargs + Additional keyword arguments. Examples -------- @@ -102,7 +112,7 @@ def __init__( plot_qc=True, directory=None, confidence_interval=0.95, - n_components=2, + n_components=1, population_to_keep="largest", *args, **kwargs, @@ -125,7 +135,7 @@ def __init__( self.ids_to_remove = None - def plot_gaussian_model( + def _get_gaussian_model_plot( self, counts, means, @@ -213,7 +223,7 @@ def plot_gaussian_model( return fig - def plot_histogram( + def _get_histogram_plot( self, values, label=None, @@ -259,7 +269,40 @@ def plot_histogram( return fig - def calculate_filtering_threshold(self, counts): + def _get_index_population(self, means, weights): + """ + Returns the index of the model that matches the population of cells to be kept. + + Parameters + ---------- + means : numpy.ndarray + An array containing the means of the models. + weights : numpy.ndarray + An array containing the weights of the models. + + Returns + ------- + int + The index of the model that matches the population criteria. + + Notes + ----- + The function determines the index of the model based on the population of cells that should be kept. + The population criteria can be set to "largest", "smallest", "mostcommon", or "leastcommon". + If the population criteria is set to "largest" or "smallest", the index is determined based on the means array. + If the population criteria is set to "mostcommon" or "leastcommon", the index is determined based on the weights array. + """ + if self.population_to_keep == "largest": + idx = np.argmax(means) + elif self.population_to_keep == "smallest": + idx = np.argmin(means) + elif self.population_to_keep == "mostcommon": + idx = np.argmax(weights) + elif self.population_to_keep == "leastcommon": + idx = np.argmin(weights) + return idx + + def _calculate_filtering_threshold(self, counts): """ Calculate the filtering thresholds for the given counts. @@ -306,10 +349,7 @@ def calculate_filtering_threshold(self, counts): weights = gmm.weights_ # get index of the model which matches to the population of cells that should be kept - if self.population_to_keep == "largest": - idx = np.argmax(means) - elif self.population_to_keep == "smallest": - idx = np.argmin(means) + idx = self._get_index_population(means, weights) # calculate the thresholds for the selected model using the given confidence interval mu = means[idx] @@ -324,7 +364,7 @@ def calculate_filtering_threshold(self, counts): threshold = (lower_threshold, upper_threshold) - fig = self.plot_gaussian_model( + fig = self._get_gaussian_model_plot( counts=data, means=means, variances=variances, @@ -343,9 +383,8 @@ def calculate_filtering_threshold(self, counts): self.log( f"Calculated threshold for {self.label} with {self.confidence_interval * 100}% confidence interval: {self.threshold}" ) - return self.threshold - def get_ids_to_remove(self, input_mask): + def _get_ids_to_remove(self, input_mask): """ Get the IDs to remove from the input mask based on the filtering threshold. @@ -371,7 +410,7 @@ def get_ids_to_remove(self, input_mask): counts = np.unique(input_mask, return_counts=True) pixel_counts = counts[1][1:] - fig = self.plot_histogram(pixel_counts, self.label) + fig = self._get_histogram_plot(pixel_counts, self.label) plt.close(fig) if self.plot_qc: @@ -379,7 +418,7 @@ def get_ids_to_remove(self, input_mask): # automatically calculate filtering threshold if not provided if self.filter_threshold is None: - self.filter_threshold = self.calculate_filtering_threshold(pixel_counts) + self._calculate_filtering_threshold(pixel_counts) ids_remove = [] _ids = counts[0][1:][np.where(pixel_counts < self.filter_threshold[0])] @@ -414,9 +453,10 @@ def filter(self, input_mask): """ if self.ids_to_remove is None: - self.get_ids_to_remove(input_mask) + self._get_ids_to_remove(input_mask) return self.update_mask(input_mask, self.ids_to_remove) + class MatchNucleusCytosolIds(BaseFilter): """ Filter class for matching nucleusIDs to their matching cytosol IDs and removing all classes from the given @@ -503,29 +543,29 @@ def __init__( ): super().__init__(*args, **kwargs) - #set relevant parameters + # set relevant parameters self.filtering_threshold = filtering_threshold - #set downsampling parameters + # set downsampling parameters if downsampling_factor is not None: - self.downsample = True + self.downsample = True self.downsampling_factor = downsampling_factor self.erosion_dilation = erosion_dilation self.smoothing_kernel_size = smoothing_kernel_size else: self.downsample = False - - #initialize placeholders for masks + + # initialize placeholders for masks self.nucleus_mask = None self.cytosol_mask = None - + # initialize datastructures for saving results self._nucleus_lookup_dict = {} self.nuclei_discard_list = [] self.cytosol_discard_list = [] self.nucleus_lookup_dict = None - def load_masks(self, nucleus_mask, cytosol_mask): + def _load_masks(self, nucleus_mask, cytosol_mask): """ Load the nucleus and cytosol masks into their placeholders. This function only loads the masks (and downsamples them if necessary) if this has not already been performed. @@ -547,8 +587,8 @@ def load_masks(self, nucleus_mask, cytosol_mask): self.nucleus_mask = nucleus_mask if self.cytosol_mask is None: self.cytosol_mask = cytosol_mask - - def update_cytosol_mask(self, cytosol_mask): + + def _get_updated_cytosol_mask(self, cytosol_mask): """ Update the cytosol mask based on the matched nucleus-cytosol pairs. @@ -565,16 +605,12 @@ def update_cytosol_mask(self, cytosol_mask): updated = np.zeros_like(cytosol_mask, dtype=bool) for nucleus_id, cytosol_id in self.nucleus_lookup_dict.items(): - condition = np.logical_and( - cytosol_mask == cytosol_id, ~updated - ) + condition = np.logical_and(cytosol_mask == cytosol_id, ~updated) cytosol_mask[condition] = nucleus_id - updated = np.logical_or( - updated, condition - ) + updated = np.logical_or(updated, condition) return cytosol_mask - - def update_masks(self): + + def _get_updated_masks(self): """ Update the nucleus and cytosol masks after filtering. @@ -588,11 +624,13 @@ def update_masks(self): if self.downsample: nucleus_mask = self.upscale_mask_basic(nucleus_mask, self.erosion_dilation) - cytosol_mask = self.upscale_mask_basic(self.cytosol_mask, self.erosion_dilation) - + cytosol_mask = self.upscale_mask_basic( + self.cytosol_mask, self.erosion_dilation + ) + return nucleus_mask, cytosol_mask - - def match_nucleus_id(self, nucleus_id): + + def _match_nucleus_id(self, nucleus_id): """ Match the given nucleus ID to a cytosol ID based on the overlapping area. @@ -630,8 +668,8 @@ def match_nucleus_id(self, nucleus_id): else: self.nuclei_discard_list.append(nucleus_id) return None - - def initialize_lookup_table(self): + + def _initialize_lookup_table(self): """ Initialize the lookup table by matching all nucleus IDs to cytosol IDs. """ @@ -639,8 +677,8 @@ def initialize_lookup_table(self): for nucleus_id in all_nucleus_ids: self.match_nucleus_id(nucleus_id) - - def count_cytosol_occurances(self): + + def _count_cytosol_occurances(self): """ Count the occurrences of each cytosol ID in the lookup table. """ @@ -648,10 +686,10 @@ def count_cytosol_occurances(self): for cytosol in self._nucleus_lookup_dict.values(): cytosol_count[cytosol] += 1 - + self.cytosol_count = cytosol_count - - def check_for_unassigned_cytosols(self): + + def _check_for_unassigned_cytosols(self): """ Check for unassigned cytosol IDs in the cytosol mask. """ @@ -660,8 +698,8 @@ def check_for_unassigned_cytosols(self): for cytosol_id in all_cytosol_ids: if cytosol_id not in self._nucleus_lookup_dict.values(): self.cytosol_discard_list.append(cytosol_id) - - def identify_multinucleated_cells(self): + + def _identify_multinucleated_cells(self): """ Identify and discard multinucleated cells from the lookup table. """ @@ -670,14 +708,14 @@ def identify_multinucleated_cells(self): self.nuclei_discard_list.append(nucleus) self.cytosol_discard_list.append(cytosol) - def cleanup_filtering_lists(self): + def _cleanup_filtering_lists(self): """ Cleanup the discard lists by removing duplicate entries. """ self.nuclei_discard_list = list(set(self.nuclei_discard_list)) self.cytosol_discard_list = list(set(self.cytosol_discard_list)) - def cleanup_lookup_dictionary(self): + def _cleanup_lookup_dictionary(self): """ Cleanup the lookup dictionary by removing discarded nucleus-cytosol pairs. """ @@ -687,13 +725,13 @@ def cleanup_lookup_dictionary(self): _cleanup.append(nucleus_id) if cytosol_id in self.cytosol_discard_list: _cleanup.append(nucleus_id) - - #ensure we have no duplicate entries + + # ensure we have no duplicate entries _cleanup = list(set(_cleanup)) for nucleus in _cleanup: del self._nucleus_lookup_dict[nucleus] - def generate_lookup_table(self, nucleus_mask, cytosol_mask): + def get_lookup_table(self, nucleus_mask, cytosol_mask): """ Generate the lookup table by performing all necessary steps. @@ -710,12 +748,12 @@ def generate_lookup_table(self, nucleus_mask, cytosol_mask): The lookup table mapping nucleus IDs to matched cytosol IDs. """ self.load_masks(nucleus_mask, cytosol_mask) - self.initialize_lookup_table() - self.count_cytosol_occurances() - self.check_for_unassigned_cytosols() - self.identify_multinucleated_cells() - self.cleanup_filtering_lists() - self.cleanup_lookup_dictionary() + self._initialize_lookup_table() + self._count_cytosol_occurances() + self._check_for_unassigned_cytosols() + self._identify_multinucleated_cells() + self._cleanup_filtering_lists() + self._cleanup_lookup_dictionary() self.nucleus_lookup_dict = self._nucleus_lookup_dict @@ -737,9 +775,9 @@ def filter(self, nucleus_mask, cytosol_mask): tuple A tuple containing the updated nucleus mask and cytosol mask after filtering. """ - self.load_masks(nucleus_mask, cytosol_mask) + self._load_masks(nucleus_mask, cytosol_mask) if self.nucleus_lookup_dict is None: - self.generate_lookup_table(nucleus_mask, cytosol_mask) - - return self.update_masks() \ No newline at end of file + self.get_lookup_table(nucleus_mask, cytosol_mask) + + return self._get_updated_masks()