Commit 7605640
Merge pull request #2 from MannLabs/code-refator_filtering
move code to match nucleus and cytosol ids to own filtering class
sophiamaedler authored Jun 5, 2024
2 parents 32254d4 + 0b62d8a commit 7605640
Showing 4 changed files with 545 additions and 446 deletions.
111 changes: 20 additions & 91 deletions src/sparcscore/pipeline/filtering_workflows.py
@@ -4,11 +4,9 @@
 )
 
 import numpy as np
-from tqdm.auto import tqdm
 import shutil
-from collections import defaultdict
 
-from sparcscore.processing.preprocessing import downsample_img_pxs
+from sparcscore.processing.filtering import MatchNucleusCytosolIds
 
 
 class BaseFiltering(SegmentationFilter):
     def __init__(self, *args, **kwargs):
@@ -26,105 +24,36 @@ class filtering_match_nucleus_to_cytosol(BaseFiltering):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-    def match_nucleus_id_to_cytosol(
-        self, nucleus_mask, cytosol_mask, return_ids_to_discard=False
-    ):
-        all_nucleus_ids = self.get_unique_ids(nucleus_mask)
-        all_cytosol_ids = self.get_unique_ids(cytosol_mask)
-
-        nucleus_cytosol_pairs = {}
-        nuclei_ids_to_discard = []
-
-        for nucleus_id in tqdm(all_nucleus_ids):
-            # get the nucleus and set the background to 0 and the nucleus to 1
-            nucleus = nucleus_mask == nucleus_id
-
-            # get the coordinates of the nucleus
-            nucleus_pixels = np.nonzero(nucleus)
-
-            # check if those indices are not background in the cytosol mask
-            potential_cytosol = cytosol_mask[nucleus_pixels]
-
-            # if there is a cytosol id in the area of the nucleus proceed, else continue with a new nucleus
-            if np.all(potential_cytosol != 0):
-                unique_cytosol, counts = np.unique(
-                    potential_cytosol, return_counts=True
-                )
-                all_counts = np.sum(counts)
-                cytosol_proportions = counts / all_counts
-
-                if np.any(cytosol_proportions >= self.config["filtering_threshold"]):
-                    # get the cytosol_id with the largest proportion
-                    cytosol_id = unique_cytosol[
-                        np.argmax(
-                            cytosol_proportions >= self.config["filtering_threshold"]
-                        )
-                    ]
-                    nucleus_cytosol_pairs[nucleus_id] = cytosol_id
-                else:
-                    # no cytosol with sufficient overlap found, so discard the nucleus
-                    nuclei_ids_to_discard.append(nucleus_id)
-            else:
-                # discard the nucleus as no matching cytosol was found
-                nuclei_ids_to_discard.append(nucleus_id)
-
-        # ensure that only one nucleus_id is assigned to each cytosol_id
-        cytosol_count = defaultdict(int)
-
-        # count the occurrences of each cytosol value
-        for cytosol in nucleus_cytosol_pairs.values():
-            cytosol_count[cytosol] += 1
-
-        # find cytosol values assigned to more than one nucleus and remove them from the dictionary
-        multi_nucleated_nucleus_ids = []
-
-        for nucleus, cytosol in nucleus_cytosol_pairs.items():
-            if cytosol_count[cytosol] > 1:
-                multi_nucleated_nucleus_ids.append(nucleus)
-
-        # update the list of discarded nuclei
-        nuclei_ids_to_discard.extend(multi_nucleated_nucleus_ids)
-
-        # remove the entries in a separate loop, because otherwise the dictionary
-        # size changes during iteration and an error is raised
-        for nucleus in multi_nucleated_nucleus_ids:
-            del nucleus_cytosol_pairs[nucleus]
-
-        # get all cytosol_ids that need to be discarded
-        used_cytosol_ids = set(nucleus_cytosol_pairs.values())
-        not_used_cytosol_ids = set(all_cytosol_ids) - used_cytosol_ids
-        not_used_cytosol_ids = list(not_used_cytosol_ids)
-
-        if return_ids_to_discard:
-            return (nucleus_cytosol_pairs, nuclei_ids_to_discard, not_used_cytosol_ids)
-        else:
-            return nucleus_cytosol_pairs
+        self.filter_threshold = self.config["filter_threshold"]
+
+        # allow for optional downsampling to improve computation time
+        if "downsampling_factor" in self.config.keys():
+            self.N = self.config["downsampling_factor"]
+            self.kernel_size = self.config["downsampling_smoothing_kernel_size"]
+            self.erosion_dilation = self.config["downsampling_erosion_dilation"]
+        else:
+            self.N = None
+            self.kernel_size = None
+            self.erosion_dilation = False
 
     def process(self, input_masks):
         if isinstance(input_masks, str):
             input_masks = self.read_input_masks(input_masks)
 
-        # allow for optional downsampling to improve computation time
-        if "downsampling_factor" in self.config.keys():
-            N = self.config["downsampling_factor"]
-            # use a less precise but faster downsampling method that preserves integer values
-            input_masks = downsample_img_pxs(input_masks, N=N)
-
-        # get input masks
-        nucleus_mask = input_masks[0, :, :]
-        cytosol_mask = input_masks[1, :, :]
-
-        nucleus_cytosol_pairs = self.match_nucleus_id_to_cytosol(
-            nucleus_mask, cytosol_mask
+        # perform filtering
+        filter = MatchNucleusCytosolIds(
+            filter_config=self.filter_threshold,
+            downsample_factor=self.N,
+            smoothing_kernel_size=self.kernel_size,
+            erosion_dilation=self.erosion_dilation,
         )
+        nucleus_cytosol_pairs = filter.generate_lookup_table(
+            input_masks[0], input_masks[1]
+        )
 
         # save results
         self.save_classes(classes=nucleus_cytosol_pairs)
 
         # cleanup TEMP directories if not done during individual tile runs
         if hasattr(self, "TEMP_DIR_NAME"):
             shutil.rmtree(self.TEMP_DIR_NAME)
 
 
 class multithreaded_filtering_match_nucleus_to_cytosol(TiledSegmentationFilter):
     method = filtering_match_nucleus_to_cytosol
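
For context: the removed method matched each nucleus to the cytosol label covering at least the configured fraction of its pixels, discarded nuclei without a complete cytosol overlap, and dropped cytosols claimed by more than one nucleus; the commit moves this logic into MatchNucleusCytosolIds. The sketch below shows how that class might be driven outside the workflow. It is a minimal, unverified example: the constructor keywords (filter_config, downsample_factor, smoothing_kernel_size, erosion_dilation) and the generate_lookup_table() call are taken verbatim from the diff above, while the toy masks, the threshold value, and the expected return value are assumptions.

import numpy as np

from sparcscore.processing.filtering import MatchNucleusCytosolIds

# toy label masks: 0 = background, positive integers = object ids
nucleus_mask = np.array(
    [
        [0, 1, 1, 0],
        [0, 1, 1, 0],
        [0, 0, 0, 0],
        [0, 2, 2, 0],
    ],
    dtype=np.uint32,
)
cytosol_mask = np.array(
    [
        [3, 3, 3, 3],
        [3, 3, 3, 3],
        [0, 0, 0, 0],
        [0, 4, 4, 4],
    ],
    dtype=np.uint32,
)

# constructor keywords mirror the call in process() above
filter = MatchNucleusCytosolIds(
    filter_config=0.5,         # matching threshold (hypothetical value)
    downsample_factor=None,    # skip downsampling for tiny masks
    smoothing_kernel_size=None,
    erosion_dilation=False,
)

# presumably returns a lookup table of nucleus id -> matched cytosol id,
# e.g. {1: 3, 2: 4} for the masks above
nucleus_cytosol_pairs = filter.generate_lookup_table(nucleus_mask, cytosol_mask)
print(nucleus_cytosol_pairs)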
7 changes: 5 additions & 2 deletions src/sparcscore/pipeline/segmentation.py
@@ -280,7 +280,7 @@ def save_segmentation_zarr(self, labels=None):
         loc = parse_url(path, mode="w").store
         group = zarr.group(store=loc)
 
-        segmentation_names = ["nucleus", "cyotosol"]
+        segmentation_names = ["nucleus", "cytosol"]
 
         # check if segmentation names already exist; if so, delete
         for seg_names in segmentation_names:
@@ -657,6 +657,7 @@ def calculate_sharding_plan(self, image_size):
     def cleanup_shards(self, sharding_plan):
         file_identifiers_plots = [".png", ".tif", ".tiff", ".jpg", ".jpeg", ".pdf"]
 
+        self.log("Moving generated plots from shard directory to main directory.")
         for i, window in enumerate(sharding_plan):
             local_shard_directory = os.path.join(self.shard_directory, str(i))
             for file in os.listdir(local_shard_directory):
@@ -945,7 +946,6 @@ def process(self, input_image):
 
     def complete_segmentation(self, input_image):
         self.save_zarr = False
-        self.save_input_image(input_image)
         self.shard_directory = os.path.join(self.directory, self.DEFAULT_SHARD_FOLDER)
 
         # check to make sure that the shard directory exists; if not, exit and return an error
@@ -954,6 +954,9 @@
                 "No Shard Directory found for the given project. Can not complete a segmentation which has not started. Please rerun the segmentation method."
             )
 
+        # save input image to segmentation.h5
+        self.save_input_image(input_image)
+
         # check to see which tiles are incomplete
         tile_directories = os.listdir(self.shard_directory)
         incomplete_indexes = []
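
The reordering in complete_segmentation() is easiest to see in isolation: the input image is now written to segmentation.h5 only after the shard directory check has passed, so a project whose segmentation never started fails fast without side effects. The sketch below condenses this control flow under stated assumptions: the class scaffold, the folder name, and the exception type are invented for illustration, while the attribute and method names follow the diff.

import os


class ShardedSegmentationSketch:
    # hypothetical default; the real class defines its own DEFAULT_SHARD_FOLDER
    DEFAULT_SHARD_FOLDER = "shards"

    def __init__(self, directory):
        self.directory = directory

    def save_input_image(self, input_image):
        ...  # in the real pipeline this writes input_image to segmentation.h5

    def complete_segmentation(self, input_image):
        self.save_zarr = False
        self.shard_directory = os.path.join(self.directory, self.DEFAULT_SHARD_FOLDER)

        # fail fast, before writing anything, if the segmentation was never started
        if not os.path.isdir(self.shard_directory):
            raise FileNotFoundError(
                "No Shard Directory found for the given project. "
                "Please rerun the segmentation method."
            )

        # only now persist the input image to segmentation.h5
        self.save_input_image(input_image)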
