Skip to content

Commit

Permalink
rename functions more consistently
Browse files Browse the repository at this point in the history
  • Loading branch information
sophiamaedler committed Jun 5, 2024
1 parent 28346c4 commit 0b62d8a
Showing 1 changed file with 101 additions and 63 deletions.
164 changes: 101 additions & 63 deletions src/sparcscore/processing/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
class BaseFilter(Logable):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def get_unique_ids(self, mask):
return np.unique(mask)[1:] # to remove the background

def update_mask(self, mask, ids_to_remove):
def get_updated_mask(self, mask, ids_to_remove):
"""
Update the given mask by setting the values corresponding to the specified IDs to 0.
Expand All @@ -36,10 +36,10 @@ def update_mask(self, mask, ids_to_remove):
"""
return np.where(np.isin(mask, ids_to_remove), 0, mask)

def downsample_mask(self, mask):
def get_downsampled_mask(self, mask):
return downsample_img_pxs(mask, N=self.downsampling_factor)

def upscale_mask_basic(self, mask, erosion_dilation=False):
def get_upscaled_mask_basic(self, mask, erosion_dilation=False):
mask = mask.repeat(self.downsampling_factor, axis=0).repeat(
self.downsampling_factor, axis=1
)
Expand All @@ -50,6 +50,7 @@ def upscale_mask_basic(self, mask, erosion_dilation=False):

return mask


class SizeFilter(BaseFilter):
"""
Filter class for removing objects from a mask based on their size.
Expand Down Expand Up @@ -80,6 +81,15 @@ class SizeFilter(BaseFilter):
The directory to save the generated figures. If not provided, the current working directory will be used.
confidence_interval : float, optional
The confidence interval for calculating the filtering threshold. Default is 0.95.
n_components : int, optional
The number of components in the Gaussian mixture model. Default is 1.
population_to_keep : str, optional
For multipopulation models this parameter determines which population should be kept. Options are "largest", "smallest", "mostcommon", "leastcommon". Default is "mostcommon".
If set to "largest" or "smallest", the model is chosen which has the largest or smallest mean. If set to "mostcommon" or "leastcommon", the model is chosen whose population is least or most common.
*args
Additional positional arguments.
**kwargs
Additional keyword arguments.
Examples
--------
Expand All @@ -102,7 +112,7 @@ def __init__(
plot_qc=True,
directory=None,
confidence_interval=0.95,
n_components=2,
n_components=1,
population_to_keep="largest",
*args,
**kwargs,
Expand All @@ -125,7 +135,7 @@ def __init__(

self.ids_to_remove = None

def plot_gaussian_model(
def _get_gaussian_model_plot(
self,
counts,
means,
Expand Down Expand Up @@ -213,7 +223,7 @@ def plot_gaussian_model(

return fig

def plot_histogram(
def _get_histogram_plot(
self,
values,
label=None,
Expand Down Expand Up @@ -259,7 +269,40 @@ def plot_histogram(

return fig

def calculate_filtering_threshold(self, counts):
def _get_index_population(self, means, weights):
"""
Returns the index of the model that matches the population of cells to be kept.
Parameters
----------
means : numpy.ndarray
An array containing the means of the models.
weights : numpy.ndarray
An array containing the weights of the models.
Returns
-------
int
The index of the model that matches the population criteria.
Notes
-----
The function determines the index of the model based on the population of cells that should be kept.
The population criteria can be set to "largest", "smallest", "mostcommon", or "leastcommon".
If the population criteria is set to "largest" or "smallest", the index is determined based on the means array.
If the population criteria is set to "mostcommon" or "leastcommon", the index is determined based on the weights array.
"""
if self.population_to_keep == "largest":
idx = np.argmax(means)
elif self.population_to_keep == "smallest":
idx = np.argmin(means)
elif self.population_to_keep == "mostcommon":
idx = np.argmax(weights)
elif self.population_to_keep == "leastcommon":
idx = np.argmin(weights)
return idx

def _calculate_filtering_threshold(self, counts):
"""
Calculate the filtering thresholds for the given counts.
Expand Down Expand Up @@ -306,10 +349,7 @@ def calculate_filtering_threshold(self, counts):
weights = gmm.weights_

# get index of the model which matches to the population of cells that should be kept
if self.population_to_keep == "largest":
idx = np.argmax(means)
elif self.population_to_keep == "smallest":
idx = np.argmin(means)
idx = self._get_index_population(means, weights)

# calculate the thresholds for the selected model using the given confidence interval
mu = means[idx]
Expand All @@ -324,7 +364,7 @@ def calculate_filtering_threshold(self, counts):

threshold = (lower_threshold, upper_threshold)

fig = self.plot_gaussian_model(
fig = self._get_gaussian_model_plot(
counts=data,
means=means,
variances=variances,
Expand All @@ -343,9 +383,8 @@ def calculate_filtering_threshold(self, counts):
self.log(
f"Calculated threshold for {self.label} with {self.confidence_interval * 100}% confidence interval: {self.threshold}"
)
return self.threshold

def get_ids_to_remove(self, input_mask):
def _get_ids_to_remove(self, input_mask):
"""
Get the IDs to remove from the input mask based on the filtering threshold.
Expand All @@ -371,15 +410,15 @@ def get_ids_to_remove(self, input_mask):
counts = np.unique(input_mask, return_counts=True)
pixel_counts = counts[1][1:]

fig = self.plot_histogram(pixel_counts, self.label)
fig = self._get_histogram_plot(pixel_counts, self.label)
plt.close(fig)

if self.plot_qc:
plt.show(fig)

# automatically calculate filtering threshold if not provided
if self.filter_threshold is None:
self.filter_threshold = self.calculate_filtering_threshold(pixel_counts)
self._calculate_filtering_threshold(pixel_counts)

ids_remove = []
_ids = counts[0][1:][np.where(pixel_counts < self.filter_threshold[0])]
Expand Down Expand Up @@ -414,9 +453,10 @@ def filter(self, input_mask):
"""
if self.ids_to_remove is None:
self.get_ids_to_remove(input_mask)
self._get_ids_to_remove(input_mask)
return self.update_mask(input_mask, self.ids_to_remove)


class MatchNucleusCytosolIds(BaseFilter):
"""
Filter class for matching nucleusIDs to their matching cytosol IDs and removing all classes from the given
Expand Down Expand Up @@ -503,29 +543,29 @@ def __init__(
):
super().__init__(*args, **kwargs)

#set relevant parameters
# set relevant parameters
self.filtering_threshold = filtering_threshold

#set downsampling parameters
# set downsampling parameters
if downsampling_factor is not None:
self.downsample = True
self.downsample = True
self.downsampling_factor = downsampling_factor
self.erosion_dilation = erosion_dilation
self.smoothing_kernel_size = smoothing_kernel_size
else:
self.downsample = False
#initialize placeholders for masks

# initialize placeholders for masks
self.nucleus_mask = None
self.cytosol_mask = None

# initialize datastructures for saving results
self._nucleus_lookup_dict = {}
self.nuclei_discard_list = []
self.cytosol_discard_list = []
self.nucleus_lookup_dict = None

def load_masks(self, nucleus_mask, cytosol_mask):
def _load_masks(self, nucleus_mask, cytosol_mask):
"""
Load the nucleus and cytosol masks into their placeholders.
This function only loads the masks (and downsamples them if necessary) if this has not already been performed.
Expand All @@ -547,8 +587,8 @@ def load_masks(self, nucleus_mask, cytosol_mask):
self.nucleus_mask = nucleus_mask
if self.cytosol_mask is None:
self.cytosol_mask = cytosol_mask
def update_cytosol_mask(self, cytosol_mask):

def _get_updated_cytosol_mask(self, cytosol_mask):
"""
Update the cytosol mask based on the matched nucleus-cytosol pairs.
Expand All @@ -565,16 +605,12 @@ def update_cytosol_mask(self, cytosol_mask):
updated = np.zeros_like(cytosol_mask, dtype=bool)

for nucleus_id, cytosol_id in self.nucleus_lookup_dict.items():
condition = np.logical_and(
cytosol_mask == cytosol_id, ~updated
)
condition = np.logical_and(cytosol_mask == cytosol_id, ~updated)
cytosol_mask[condition] = nucleus_id
updated = np.logical_or(
updated, condition
)
updated = np.logical_or(updated, condition)
return cytosol_mask
def update_masks(self):

def _get_updated_masks(self):
"""
Update the nucleus and cytosol masks after filtering.
Expand All @@ -588,11 +624,13 @@ def update_masks(self):

if self.downsample:
nucleus_mask = self.upscale_mask_basic(nucleus_mask, self.erosion_dilation)
cytosol_mask = self.upscale_mask_basic(self.cytosol_mask, self.erosion_dilation)

cytosol_mask = self.upscale_mask_basic(
self.cytosol_mask, self.erosion_dilation
)

return nucleus_mask, cytosol_mask
def match_nucleus_id(self, nucleus_id):

def _match_nucleus_id(self, nucleus_id):
"""
Match the given nucleus ID to a cytosol ID based on the overlapping area.
Expand Down Expand Up @@ -630,28 +668,28 @@ def match_nucleus_id(self, nucleus_id):
else:
self.nuclei_discard_list.append(nucleus_id)
return None
def initialize_lookup_table(self):

def _initialize_lookup_table(self):
"""
Initialize the lookup table by matching all nucleus IDs to cytosol IDs.
"""
all_nucleus_ids = self.get_unique_ids(self.nucleus_mask)

for nucleus_id in all_nucleus_ids:
self.match_nucleus_id(nucleus_id)
def count_cytosol_occurances(self):

def _count_cytosol_occurances(self):
"""
Count the occurrences of each cytosol ID in the lookup table.
"""
cytosol_count = defaultdict(int)

for cytosol in self._nucleus_lookup_dict.values():
cytosol_count[cytosol] += 1

self.cytosol_count = cytosol_count
def check_for_unassigned_cytosols(self):

def _check_for_unassigned_cytosols(self):
"""
Check for unassigned cytosol IDs in the cytosol mask.
"""
Expand All @@ -660,8 +698,8 @@ def check_for_unassigned_cytosols(self):
for cytosol_id in all_cytosol_ids:
if cytosol_id not in self._nucleus_lookup_dict.values():
self.cytosol_discard_list.append(cytosol_id)
def identify_multinucleated_cells(self):

def _identify_multinucleated_cells(self):
"""
Identify and discard multinucleated cells from the lookup table.
"""
Expand All @@ -670,14 +708,14 @@ def identify_multinucleated_cells(self):
self.nuclei_discard_list.append(nucleus)
self.cytosol_discard_list.append(cytosol)

def cleanup_filtering_lists(self):
def _cleanup_filtering_lists(self):
"""
Cleanup the discard lists by removing duplicate entries.
"""
self.nuclei_discard_list = list(set(self.nuclei_discard_list))
self.cytosol_discard_list = list(set(self.cytosol_discard_list))

def cleanup_lookup_dictionary(self):
def _cleanup_lookup_dictionary(self):
"""
Cleanup the lookup dictionary by removing discarded nucleus-cytosol pairs.
"""
Expand All @@ -687,13 +725,13 @@ def cleanup_lookup_dictionary(self):
_cleanup.append(nucleus_id)
if cytosol_id in self.cytosol_discard_list:
_cleanup.append(nucleus_id)
#ensure we have no duplicate entries

# ensure we have no duplicate entries
_cleanup = list(set(_cleanup))
for nucleus in _cleanup:
del self._nucleus_lookup_dict[nucleus]

def generate_lookup_table(self, nucleus_mask, cytosol_mask):
def get_lookup_table(self, nucleus_mask, cytosol_mask):
"""
Generate the lookup table by performing all necessary steps.
Expand All @@ -710,12 +748,12 @@ def generate_lookup_table(self, nucleus_mask, cytosol_mask):
The lookup table mapping nucleus IDs to matched cytosol IDs.
"""
self.load_masks(nucleus_mask, cytosol_mask)
self.initialize_lookup_table()
self.count_cytosol_occurances()
self.check_for_unassigned_cytosols()
self.identify_multinucleated_cells()
self.cleanup_filtering_lists()
self.cleanup_lookup_dictionary()
self._initialize_lookup_table()
self._count_cytosol_occurances()
self._check_for_unassigned_cytosols()
self._identify_multinucleated_cells()
self._cleanup_filtering_lists()
self._cleanup_lookup_dictionary()

self.nucleus_lookup_dict = self._nucleus_lookup_dict

Expand All @@ -737,9 +775,9 @@ def filter(self, nucleus_mask, cytosol_mask):
tuple
A tuple containing the updated nucleus mask and cytosol mask after filtering.
"""
self.load_masks(nucleus_mask, cytosol_mask)
self._load_masks(nucleus_mask, cytosol_mask)

if self.nucleus_lookup_dict is None:
self.generate_lookup_table(nucleus_mask, cytosol_mask)
return self.update_masks()
self.get_lookup_table(nucleus_mask, cytosol_mask)

return self._get_updated_masks()

0 comments on commit 0b62d8a

Please sign in to comment.