Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move code to match nucleus and cytosol ids to own filtering class #2

Merged
merged 20 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
69fc104
add filter to match cytosol and nucleus ids
sophiamaedler Jun 3, 2024
396e97c
update segmentation workflow to utilize prewritten filter
sophiamaedler Jun 3, 2024
8c6aa02
update filtering workflow to utilize separate filtering function
sophiamaedler Jun 3, 2024
7ad4093
fix inconsistent variable naming
sophiamaedler Jun 3, 2024
3da72f3
ensure cytosol mask is updated with the correct function
sophiamaedler Jun 3, 2024
f23b9a0
add docstrings to filtering method
sophiamaedler Jun 3, 2024
4edefad
remove potential duplicates from cleanup function before processing d…
sophiamaedler Jun 3, 2024
218d936
add log messages tracking how many cells were removed through filtering
sophiamaedler Jun 3, 2024
3d632b3
merge "main" into branch code_refactoring_filtering
sophiamaedler Jun 4, 2024
dfa5d40
remove unused variable
sophiamaedler Jun 4, 2024
ecf30b9
add function to get downsampling parameters
sophiamaedler Jun 4, 2024
643960d
implement sanity check to catch cases #4
sophiamaedler Jun 4, 2024
52e0d0a
Merge pull request #5 from MannLabs:fix_issue#4
sophiamaedler Jun 4, 2024
d0c0ef6
add missing part of previous commit
sophiamaedler Jun 4, 2024
623ef03
relocate function to calculate downsampling parameters
sophiamaedler Jun 4, 2024
0c09096
fix function call
sophiamaedler Jun 4, 2024
d293909
fix small bug in incorrect calculation of number of classes
sophiamaedler Jun 4, 2024
b00b7fc
fix typo
sophiamaedler Jun 5, 2024
28346c4
add additional log message
sophiamaedler Jun 5, 2024
0b62d8a
rename functions more consistently
sophiamaedler Jun 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 20 additions & 91 deletions src/sparcscore/pipeline/filtering_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
)

import numpy as np
from tqdm.auto import tqdm
import shutil
from collections import defaultdict

from sparcscore.processing.preprocessing import downsample_img_pxs
from sparcscore.processing.filtering import MatchNucleusCytosolIds


class BaseFiltering(SegmentationFilter):
def __init__(self, *args, **kwargs):
Expand All @@ -26,105 +24,36 @@ class filtering_match_nucleus_to_cytosol(BaseFiltering):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def match_nucleus_id_to_cytosol(
self, nucleus_mask, cytosol_mask, return_ids_to_discard=False
):
all_nucleus_ids = self.get_unique_ids(nucleus_mask)
all_cytosol_ids = self.get_unique_ids(cytosol_mask)

nucleus_cytosol_pairs = {}
nuclei_ids_to_discard = []

for nucleus_id in tqdm(all_nucleus_ids):
# get the nucleus and set the background to 0 and the nucleus to 1
nucleus = nucleus_mask == nucleus_id

# now get the coordinates of the nucleus
nucleus_pixels = np.nonzero(nucleus)

# check if those indices are not background in the cytosol mask
potential_cytosol = cytosol_mask[nucleus_pixels]

# if there is a cytosolID in the area of the nucleus proceed, else continue with a new nucleus
if np.all(potential_cytosol != 0):
unique_cytosol, counts = np.unique(
potential_cytosol, return_counts=True
)
all_counts = np.sum(counts)
cytosol_proportions = counts / all_counts

if np.any(cytosol_proportions >= self.config["filtering_threshold"]):
# get the cytosol_id with max proportion
cytosol_id = unique_cytosol[
np.argmax(
cytosol_proportions >= self.config["filtering_threshold"]
)
]
nucleus_cytosol_pairs[nucleus_id] = cytosol_id
else:
# no cytosol found with sufficient quality to call so discard nucleus
nuclei_ids_to_discard.append(nucleus_id)

else:
# discard nucleus as no matching cytosol found
nuclei_ids_to_discard.append(nucleus_id)

# check to ensure that only one nucleus_id is assigned to each cytosol_id
cytosol_count = defaultdict(int)

# Count the occurrences of each cytosol value
for cytosol in nucleus_cytosol_pairs.values():
cytosol_count[cytosol] += 1

# Find cytosol values assigned to more than one nucleus and remove from dictionary
multi_nucleated_nulceus_ids = []
self.filter_threshold = self.config["filter_threshold"]

for nucleus, cytosol in nucleus_cytosol_pairs.items():
if cytosol_count[cytosol] > 1:
multi_nucleated_nulceus_ids.append(nucleus)

# update list of all nuclei used
nuclei_ids_to_discard.append(multi_nucleated_nulceus_ids)

# remove entries from dictionary
# this needs to be put into a seperate loop because otherwise the dictionary size changes during loop and this throws an error
for nucleus in multi_nucleated_nulceus_ids:
del nucleus_cytosol_pairs[nucleus]

# get all cytosol_ids that need to be discarded
used_cytosol_ids = set(nucleus_cytosol_pairs.values())
not_used_cytosol_ids = set(all_cytosol_ids) - used_cytosol_ids
not_used_cytosol_ids = list(not_used_cytosol_ids)

if return_ids_to_discard:
return (nucleus_cytosol_pairs, nuclei_ids_to_discard, not_used_cytosol_ids)
# allow for optional downsampling to improve computation time
if "downsampling_factor" in self.config.keys():
self.N = self.config["downsampling_factor"]
self.kernel_size = self.config["downsampling_smoothing_kernel_size"]
self.erosion_dilation = self.config["downsampling_erosion_dilation"]
else:
return nucleus_cytosol_pairs
self.N = None
self.kernel_size = None
self.erosion_dilation = False

def process(self, input_masks):
if isinstance(input_masks, str):
input_masks = self.read_input_masks(input_masks)

# allow for optional downsampling to improve computation time
if "downsampling_factor" in self.config.keys():
N = self.config["downsampling_factor"]
# use a less precise but faster downsampling method that preserves integer values
input_masks = downsample_img_pxs(input_masks, N=N)

# get input masks
nucleus_mask = input_masks[0, :, :]
cytosol_mask = input_masks[1, :, :]

nucleus_cytosol_pairs = self.match_nucleus_id_to_cytosol(
nucleus_mask, cytosol_mask
# perform filtering
filter = MatchNucleusCytosolIds(
filter_config=self.filter_threshold,
downsample_factor=self.N,
smoothing_kernel_size=self.kernel_size,
erosion_dilation=self.erosion_dilation,
)
nucleus_cytosol_pairs = filter.generate_lookup_table(
input_masks[0], input_masks[1]
)

# save results
self.save_classes(classes=nucleus_cytosol_pairs)

# cleanup TEMP directories if not done during individual tile runs
if hasattr(self, "TEMP_DIR_NAME"):
shutil.rmtree(self.TEMP_DIR_NAME)

class multithreaded_filtering_match_nucleus_to_cytosol(TiledSegmentationFilter):
method = filtering_match_nucleus_to_cytosol
7 changes: 5 additions & 2 deletions src/sparcscore/pipeline/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def save_segmentation_zarr(self, labels=None):
loc = parse_url(path, mode="w").store
group = zarr.group(store=loc)

segmentation_names = ["nucleus", "cyotosol"]
segmentation_names = ["nucleus", "cytosol"]

# check if segmentation names already exist if so delete
for seg_names in segmentation_names:
Expand Down Expand Up @@ -657,6 +657,7 @@ def calculate_sharding_plan(self, image_size):
def cleanup_shards(self, sharding_plan):
file_identifiers_plots = [".png", ".tif", ".tiff", ".jpg", ".jpeg", ".pdf"]

self.log("Moving generated plots from shard directory to main directory.")
for i, window in enumerate(sharding_plan):
local_shard_directory = os.path.join(self.shard_directory, str(i))
for file in os.listdir(local_shard_directory):
Expand Down Expand Up @@ -945,7 +946,6 @@ def process(self, input_image):

def complete_segmentation(self, input_image):
self.save_zarr = False
self.save_input_image(input_image)
self.shard_directory = os.path.join(self.directory, self.DEFAULT_SHARD_FOLDER)

# check to make sure that the shard directory exisits, if not exit and return error
Expand All @@ -954,6 +954,9 @@ def complete_segmentation(self, input_image):
"No Shard Directory found for the given project. Can not complete a segmentation which has not started. Please rerun the segmentation method."
)

#save input image to segmentation.h5
self.save_input_image(input_image)

# check to see which tiles are incomplete
tile_directories = os.listdir(self.shard_directory)
incomplete_indexes = []
Expand Down
Loading
Loading