Merge pull request #17 from HelmholtzAI-Consultants-Munich/main-review
* Code review
* Changed model weights & download option in GUI
* Made editable text boxes along with sliders
* Fixed bug that allows entering a new labels layer in the back-end
christinab12 authored Jul 17, 2023
2 parents 896ac3a + 6b21196 commit 224eff9
Showing 8 changed files with 452 additions and 194 deletions.
5 changes: 4 additions & 1 deletion .napari/DESCRIPTION.md
@@ -28,10 +28,13 @@ You can install `napari-organoid-counter` via [pip](https://pypi.org/project/nap
To install the latest development version:

pip install git+https://github.com/HelmholtzAI-Consultants-Munich/napari-organoid-counter.git


For installing on a Windows machine via napari, follow the instructions [here](https://github.com/HelmholtzAI-Consultants-Munich/napari-organoid-counter/blob/main/readme-content/How%20to%20install%20on%20a%20Windows%20machine.pdf).

## Quickstart

After installation, you can start napari (either by typing ```napari``` in your terminal or by launching the application) and select the plugin from the drop-down menu.
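Equivalently, napari can be started from a Python session; the sketch below uses napari's public API (`napari.Viewer` and `napari.run`):

```python
import napari

viewer = napari.Viewer()  # opens an empty napari window
napari.run()              # starts the event loop when running from a plain script
```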

The use of the napari-organoid-counter plugin is straightforward. Here is a step-by-step guide to the standard workflow:
1. You can load the image or images you wish to process into the napari viewer with a simple drag and drop
2. You can then select the layer you wish to work on via the drop-down box at the top of the input configurations
5 changes: 3 additions & 2 deletions README.md
@@ -34,13 +34,14 @@ For the dev branch you can clone this repo and install with:

pip install -e .

Then run napari in your terminal.

For installing on a Windows machine via napari, follow the instructions [here](https://github.com/HelmholtzAI-Consultants-Munich/napari-organoid-counter/blob/main/readme-content/How%20to%20install%20on%20a%20Windows%20machine.pdf).

## What's new in v2?
Check out our *What's New in v2* [here](https://github.com/HelmholtzAI-Consultants-Munich/napari-organoid-counter/blob/main/.napari/DESCRIPTION.md#whats-new-in-v2).

## How to use?
After installing, you can start napari (either by typing ```napari``` in your terminal or by launching the application) and select the plugin from the drop-down menu.

For more information on this plugin and its intended audience, as well as a quickstart, see our [Quickstart guide](https://github.com/HelmholtzAI-Consultants-Munich/napari-organoid-counter/blob/main/.napari/DESCRIPTION.md#quickstart).

## Contributing
185 changes: 153 additions & 32 deletions napari_organoid_counter/_orgacount.py
@@ -1,48 +1,124 @@
import os
import torch
from torchvision.transforms import ToTensor
from napari_organoid_counter._utils import frcnn, prepare_img, apply_nms, convert_boxes_to_napari_view, convert_boxes_from_napari_view, get_diams
import subprocess

from urllib.request import urlretrieve
from napari.utils import progress

from napari_organoid_counter._utils import *
from napari_organoid_counter import settings


class OrganoiDL():
def __init__(self,
img_scale,
model_checkpoint=None):
'''
The back-end of the organoid counter widget
Attributes
----------
device: torch.device
The current device, either 'cuda' or 'cpu'
cur_confidence: float
The confidence threshold of the model
cur_min_diam: float
The minimum diameter of the organoids
transfroms: torchvision.transforms.ToTensor
The transform that converts a numpy image to a tensor so it can be given as input to the model
model: frcnn
The Faster R-CNN model
img_scale: list of floats
A list holding the image resolution in x and y
pred_bboxes: dict
Each key identifies a set of predictions of the model, either past or current; the values are numpy arrays
holding the predicted bounding boxes
pred_scores: dict
Each key identifies a set of predictions of the model; the values hold the confidence of the model for each
predicted bounding box
pred_ids: dict
Each key identifies a set of predictions of the model; the values hold the box id for each
predicted bounding box
next_id: dict
Each key identifies a set of predictions of the model; the values hold the next id to be attributed to a
newly added box
'''
def __init__(self, handle_progress):
super().__init__()

self.handle_progress = handle_progress
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.cur_confidence = 0.05
self.cur_min_diam = 30
self.model = frcnn(num_classes=2, rpn_score_thresh=0, box_score_thresh = self.cur_confidence)
self.model_checkpoint = model_checkpoint
if not os.path.isfile(self.model_checkpoint):
self.download_model()
self.load_model_checkpoint(self.model_checkpoint)
self.model = self.model.to(self.device)
self.transfroms = ToTensor()

self.img_scale = img_scale
self.model = None
self.img_scale = [0., 0.]
self.pred_bboxes = {}
self.pred_scores = {}
self.pred_ids = {}
self.next_id = {}

def download_model(self):
subprocess.run(["zenodo_get","10.5281/zenodo.7708763","-o", "model"])
self.model_checkpoint = os.path.join(os.getcwd(), 'model', 'model_v1.ckpt')
#zenodo_get(['10.5281/zenodo.7708763'])
def set_scale(self, img_scale):
''' Set the image scale: used to calculate real box sizes. '''
self.img_scale = img_scale

def load_model_checkpoint(self, model_path):
ckpt = torch.load(model_path, map_location=self.device)
self.model.load_state_dict(ckpt) #.state_dict())
def set_model(self, model_name):
''' Initialise the model instance, load the model checkpoint and send the model to the device. '''
self.model = frcnn(num_classes=2, rpn_score_thresh=0, box_score_thresh = self.cur_confidence)
self.load_model_checkpoint(model_name)
self.model = self.model.to(self.device)

def sliding_window(self, test_img, step, window_size, rescale_factor, prepadded_height, prepadded_width, pred_bboxes=[], scores_list=[]):

#img_height, img_width = test_img.size(2), test_img.size(3)
def download_model(self, model='default'):
''' Downloads the model from Zenodo and stores it in settings.MODELS_DIR. '''
# specify the url of the file which is to be downloaded
down_url = settings.MODELS[model]["source"]
# specify save location where the file is to be saved
save_loc = join_paths(str(settings.MODELS_DIR), settings.MODELS[model]["filename"])
# Downloading using urllib
urlretrieve(down_url,save_loc, self.handle_progress)

for i in range(0, prepadded_height, step):
for j in range(0, prepadded_width, step):

def load_model_checkpoint(self, model_name):
''' Loads the model checkpoint for the model specified in model_name '''
model_checkpoint = join_paths(settings.MODELS_DIR, settings.MODELS[model_name]["filename"])
ckpt = torch.load(model_checkpoint, map_location=self.device)
self.model.load_state_dict(ckpt) #.state_dict())
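Taken together, the refactored methods above suggest the following usage lifecycle. This is a hedged sketch, not code from the diff: the progress callback follows `urllib.request.urlretrieve`'s reporthook signature, the scale values are illustrative, and `settings.MODELS_DIR` is assumed to exist.

```python
from napari_organoid_counter._orgacount import OrganoiDL

def report(block_num, block_size, total_size):
    # reporthook called repeatedly by urlretrieve while the checkpoint downloads
    print(f"downloaded {block_num * block_size} of {total_size} bytes")

orga = OrganoiDL(handle_progress=report)
orga.set_scale([0.9, 0.9])      # microns per pixel in x and y (illustrative values)
orga.download_model('default')  # fetch the checkpoint from Zenodo into settings.MODELS_DIR
orga.set_model('default')       # build the Faster R-CNN, load the checkpoint, move to device
```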

def sliding_window(self,
test_img,
step,
window_size,
rescale_factor,
prepadded_height,
prepadded_width,
pred_bboxes=[],
scores_list=[]):
''' Runs sliding window inference and returns predicted bounding boxes and confidence scores for each box.
Inputs
----------
test_img: Tensor of size [B, C, H, W]
The image ready to be given to model as input
step: int
The step of the sliding window, same in x and y
window_size: int
The sliding window size, same in x and y
rescale_factor: float
The rescaling factor by which the image has already been resized. Is 1/downsampling
prepadded_height: int
The image height before padding was applied
prepadded_width: int
The image width before padding was applied
pred_bboxes: list of Tensors, default is an empty list
The predicted bounding boxes accumulated from any previous runs of the sliding window
scores_list: list of Tensors, default is an empty list
The confidence scores accumulated from any previous runs of the sliding window
Outputs
----------
pred_bboxes: list of Tensors, default is an empty list
The resulting predicted boxes are appended here; if the model is run at several window
sizes and downsampling rates, the list accumulates the results of every run, so it is
only empty before the first run.
scores_list: list of Tensors, default is an empty list
The model's confidence scores for the predicted boxes are appended here;
like pred_bboxes, it accumulates results across runs.
'''
for i in progress(range(0, prepadded_height, step)):
for j in progress(range(0, prepadded_width, step)):
# crop
img_crop = test_img[:, :, i:(i+window_size), j:(j+window_size)]
# get predictions
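In isolation, the overlapping tiling that `sliding_window` iterates over looks like the sketch below (dimensions and overlap are illustrative, and the dummy tensor stands in for a prepared image):

```python
import torch

test_img = torch.zeros(1, 3, 1024, 1024)  # [B, C, H, W] dummy prepared image
window_size = 512
step = round(window_size * 0.5)           # window_overlap = 0.5 -> adjacent crops overlap by half
for i in range(0, 1024, step):
    for j in range(0, 1024, step):
        crop = test_img[:, :, i:i + window_size, j:j + window_size]
        # each crop is what would be passed through the detection model
```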
@@ -66,21 +142,47 @@ def run(self,
window_sizes,
downsampling_sizes,
window_overlap):

# run for all window sizes
''' Runs inference for an image at multiple window sizes and downsampling rates using sliding window inference.
The results are filtered with the NMS algorithm and then stored in dicts.
Inputs
----------
img: Numpy array of size [H, W]
The image ready to be given to model as input
shapes_name: str
The name of the new predictions
window_sizes: list of ints
The sliding window sizes, same in x and y; if several sizes are given the sliding window runs once per size
downsampling_sizes: list of ints
The downsampling factors of the image; the list size must match window_sizes
window_overlap: float
The window overlap for the sliding window inference.
'''
bboxes = []
scores = []

# run for all window sizes
for window_size, downsampling in zip(window_sizes, downsampling_sizes):
# compute the step for the sliding window, based on window overlap
rescale_factor = 1 / downsampling
# window size after rescaling
window_size = round(window_size * rescale_factor)
step = round(window_size * window_overlap)
# prepare image for model - norm, tensor, etc.
ready_img, prepadded_height, prepadded_width = prepare_img(img, step, window_size, rescale_factor, self.transfroms , self.device)
bboxes, scores = self.sliding_window(ready_img, step, window_size, rescale_factor, prepadded_height, prepadded_width, bboxes, scores)

ready_img, prepadded_height, prepadded_width = prepare_img(img,
step,
window_size,
rescale_factor,
self.transfroms,
self.device)
# and run sliding window over whole image
bboxes, scores = self.sliding_window(ready_img,
step,
window_size,
rescale_factor,
prepadded_height,
prepadded_width,
bboxes,
scores)
# stack results
bboxes = torch.stack(bboxes)
scores = torch.stack(scores)
# apply NMS to remove overlapping boxes
@@ -92,6 +194,8 @@ def run(self,
self.next_id[shapes_name] = num_predictions+1
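Continuing the earlier sketch, a multi-scale call to `run` might look as follows (the image and all argument values are illustrative):

```python
import numpy as np

img = np.zeros((2048, 2048), dtype=np.uint16)  # dummy 2D image [H, W]
orga.run(img,
         shapes_name='Labels-Test',
         window_sizes=[2048, 1024],
         downsampling_sizes=[4, 2],  # one downsampling factor per window size
         window_overlap=0.5)
```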

def apply_params(self, shapes_name, confidence, min_diameter_um):
""" After results have been stored in dict this function will filter the dicts based on the confidence
and min_diameter_um thresholds for the given results defined by shape_name and return the filtered dicts. """
self.cur_confidence = confidence
self.cur_min_diam = min_diameter_um
pred_bboxes, pred_scores, pred_ids = self._apply_confidence_thresh(shapes_name)
@@ -101,6 +205,7 @@ def apply_params(self, shapes_name, confidence, min_diameter_um):
return pred_bboxes, pred_scores, pred_ids
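A hedged example of re-filtering the stored predictions with new thresholds (values illustrative):

```python
bboxes, scores, ids = orga.apply_params('Labels-Test',
                                        confidence=0.8,
                                        min_diameter_um=30)
```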

def _apply_confidence_thresh(self, shapes_name):
""" Filters out results of shapes_name based on the current confidence threshold. """
if shapes_name not in self.pred_bboxes.keys(): return torch.empty((0))
keep = (self.pred_scores[shapes_name]>self.cur_confidence).nonzero(as_tuple=True)[0]
result_bboxes = self.pred_bboxes[shapes_name][keep]
@@ -109,6 +214,7 @@ def _apply_confidence_thresh(self, shapes_name):
return result_bboxes, result_scores, result_ids
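The confidence-filtering pattern above, shown in isolation:

```python
import torch

scores = torch.tensor([0.9, 0.2, 0.7])
keep = (scores > 0.5).nonzero(as_tuple=True)[0]  # indices above threshold: tensor([0, 2])
filtered = scores[keep]                          # tensor([0.9000, 0.7000])
```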

def _filter_small_organoids(self, pred_bboxes, pred_scores, pred_ids):
""" Filters out small result boxes of shapes_name based on the current min diameter size. """
if pred_bboxes is None: return None
if len(pred_bboxes)==0: return None
min_diameter_x = self.cur_min_diam / self.img_scale[0]
@@ -123,6 +229,11 @@ def _filter_small_organoids(self, pred_bboxes, pred_scores, pred_ids):
return pred_bboxes, pred_scores, pred_ids
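A worked example of the micron-to-pixel conversion this filter relies on (values illustrative):

```python
img_scale = [0.5, 0.5]                        # image resolution in microns per pixel
cur_min_diam = 30                             # minimum organoid diameter in microns
min_diameter_x = cur_min_diam / img_scale[0]  # 60.0 pixels: narrower boxes are filtered out
```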

def update_bboxes_scores(self, shapes_name, new_bboxes, new_scores, new_ids):
''' Updates the result dicts, self.pred_bboxes, self.pred_scores and self.pred_ids, with new results.
If shapes_name doesn't exist as a key in the dicts, the results are added under the new key. If the
key exists, new_bboxes, new_scores and new_ids are compared to the stored results and the dicts
are updated, either by adding a box (user added a box) or removing one (user deleted a prediction).'''

new_bboxes = convert_boxes_from_napari_view(new_bboxes)
new_scores = torch.Tensor(list(new_scores))
new_ids = list(new_ids)
Expand All @@ -131,7 +242,10 @@ def update_bboxes_scores(self, shapes_name, new_bboxes, new_scores, new_ids):
self.pred_bboxes[shapes_name] = new_bboxes
self.pred_scores[shapes_name] = new_scores
self.pred_ids[shapes_name] = new_ids
self.next_id[shapes_name] = len(new_ids)+1

elif len(new_ids)==0: return

else:
min_diameter_x = self.cur_min_diam / self.img_scale[0]
min_diameter_y = self.cur_min_diam / self.img_scale[1]
@@ -163,7 +277,14 @@ def update_bboxes_scores(self, shapes_name, new_bboxes, new_scores, new_ids):
self.pred_ids[shapes_name] = new_pred_ids
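A hedged sketch of syncing user edits back into the result dicts; boxes are assumed to be in napari view format (four corner points per rectangle) and all values are illustrative:

```python
import numpy as np

edited_box = np.array([[0., 0.], [0., 50.], [50., 50.], [50., 0.]])  # e.g. from the shapes layer
orga.update_bboxes_scores('Labels-Test',
                          new_bboxes=[edited_box],
                          new_scores=[1.0],  # a user-added box gets a placeholder confidence
                          new_ids=[2])
```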

def update_next_id(self, shapes_name, c=0):
""" Updates the next id to append to result dicts. If input c is given then that will be the next id. """
if c!=0:
self.next_id[shapes_name] = c
else: self.next_id[shapes_name] += 1

def remove_shape_from_dict(self, shapes_name):
""" Removes results of shapes_name from all result dicts. """
del self.pred_bboxes[shapes_name]
del self.pred_scores[shapes_name]
del self.pred_ids[shapes_name]
del self.next_id[shapes_name]
8 changes: 6 additions & 2 deletions napari_organoid_counter/_tests/test_widget.py
@@ -20,7 +20,10 @@ def test_organoid_counter_widget(make_napari_viewer, capsys):
viewer.add_image(test_img, name='Test')

# create our widget, passing in the viewer
my_widget = OrganoidCounterWidget(viewer)
my_widget = OrganoidCounterWidget(viewer,
window_sizes=[500],
downsampling=[2],
window_overlap=1)

# call preprocessing - remove duplicate here?
my_widget._preprocess()
@@ -34,12 +37,13 @@ def test_organoid_counter_widget(make_napari_viewer, capsys):
assert np.min(img) >= 0.

# test that organoid counting algorithm has run and new layer with res has been added to viewer
my_widget.organoiDL.download_model(my_widget.model_name)
my_widget._on_run_click()
layer_names = [layer.name for layer in my_widget.viewer.layers]
assert 'Labels-Test' in layer_names

# test that class attributes are updated when user changes values from GUI
my_widget._on_diameter_changed()
my_widget._on_diameter_slider_changed()
assert my_widget.min_diameter == my_widget.min_diameter_slider.value()

my_widget._on_image_selection_changed()
33 changes: 33 additions & 0 deletions napari_organoid_counter/_utils.py
@@ -1,4 +1,7 @@
from contextlib import contextmanager
import os
from pathlib import Path

import numpy as np
import math
import json
@@ -12,6 +15,36 @@
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.ops import nms

from napari_organoid_counter import settings


def add_local_models():
""" Checks the models directory for any local models previously added by the user.
If some are found then these are added to the model dictionary (see settings). """
if not os.path.exists(settings.MODELS_DIR): return
model_names_in_dir = [file for file in os.listdir(settings.MODELS_DIR)]
model_names_in_dict = [settings.MODELS[key]["filename"] for key in settings.MODELS.keys()]
for model_name in model_names_in_dir:
if model_name not in model_names_in_dict and model_name.endswith(settings.MODEL_TYPE):
_ = add_to_dict(model_name)

def add_to_dict(filepath):
""" Given the full path and name of a model in filepath the model is added to the models dict (see settings)"""
filepath = Path(filepath)
name = filepath.name
stem_name = filepath.stem
settings.MODELS[stem_name] = {"filename": name, "source": "local"}
return stem_name
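A hypothetical registration of a user-provided checkpoint (the path is illustrative):

```python
stem = add_to_dict("/home/me/models/my_organoids.ckpt")
# settings.MODELS["my_organoids"] == {"filename": "my_organoids.ckpt", "source": "local"}
```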

def return_is_file(path, filename):
""" Return True if the file exists in path and False otherwise """
full_path = join_paths(path, filename)
return os.path.isfile(full_path)

def join_paths(path1, path2):
""" Returns output of os.path.join """
return os.path.join(path1, path2)

@contextmanager
def set_dict_key(dictionary, key, value):
""" Used to set a new value in the napari layer metadata """