diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml
index 08893b8..9546c84 100644
--- a/.github/workflows/test_and_deploy.yml
+++ b/.github/workflows/test_and_deploy.yml
@@ -49,6 +49,23 @@ jobs:
with:
qt: true
+ # cache atlases needed by the tests
+ - name: Cache Atlases
+ id: atlas-cache
+ uses: actions/cache@v3
+ with:
+ path: | # ensure we don't cache any interrupted atlas download and extraction!
+ ~/.brainglobe/*
+ !~/.brainglobe/atlas.tar.gz
+ key: ${{ runner.os }}-cached-atlases
+ enableCrossOsArchive: false # ~ and $HOME evaluate to different places across OSs!
+
+ - if: ${{ steps.atlas-cache.outputs.cache-hit == 'true' }}
+ name: List files in brainglobe data folder # good to be able to sanity check that user data is as expected
+ run: |
+ ls -af ~/.brainglobe/
+
+
# Run tests
- uses: neuroinformatics-unit/actions/test@v2
with:
diff --git a/.gitignore b/.gitignore
index 73d56d3..08e97e3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -82,3 +82,6 @@ venv/
# written by setuptools_scm
**/_version.py
+
+# Elastix related files
+/Logs/
diff --git a/brainglobe_registration/elastix/register.py b/brainglobe_registration/elastix/register.py
index d04d918..4a216f9 100644
--- a/brainglobe_registration/elastix/register.py
+++ b/brainglobe_registration/elastix/register.py
@@ -1,7 +1,8 @@
-from typing import List
+from typing import List, Tuple
import itk
import numpy as np
+import numpy.typing as npt
from brainglobe_atlasapi import BrainGlobeAtlas
@@ -28,28 +29,30 @@ def run_registration(
atlas_image,
moving_image,
annotation_image,
- parameter_lists: List[tuple[str, dict]],
-) -> tuple[np.ndarray, itk.ParameterObject, np.ndarray]:
+ parameter_lists: List[Tuple[str, dict]],
+) -> Tuple[npt.NDArray, itk.ParameterObject, npt.NDArray]:
"""
Run the registration process on the given images.
Parameters
----------
- atlas_image : np.ndarray
+ atlas_image : npt.NDArray
The atlas image.
- moving_image : np.ndarray
+ moving_image : npt.NDArray
The moving image.
- annotation_image : np.ndarray
+ annotation_image : npt.NDArray
The annotation image.
parameter_lists : List[tuple[str, dict]], optional
The list of parameter lists, by default None
Returns
-------
- np.ndarray
+ npt.NDArray
The result image.
itk.ParameterObject
The result transform parameters.
+ npt.NDArray
+ The transformed annotation image.
"""
# convert to ITK, view only
atlas_image = itk.GetImageViewFromArray(atlas_image).astype(itk.F)
diff --git a/brainglobe_registration/registration_widget.py b/brainglobe_registration/registration_widget.py
index d389d0c..4d0f367 100644
--- a/brainglobe_registration/registration_widget.py
+++ b/brainglobe_registration/registration_widget.py
@@ -9,25 +9,32 @@
"""
from pathlib import Path
+from typing import Optional
+import dask.array as da
+import napari.layers
import numpy as np
+import numpy.typing as npt
from brainglobe_atlasapi import BrainGlobeAtlas
from brainglobe_atlasapi.list_atlases import get_downloaded_atlases
+from brainglobe_utils.qtpy.collapsible_widget import CollapsibleWidgetContainer
+from brainglobe_utils.qtpy.dialog import display_info
from brainglobe_utils.qtpy.logo import header_widget
+from dask_image.ndinterp import affine_transform as dask_affine_transform
+from napari.qt.threading import thread_worker
+from napari.utils.events import Event
+from napari.utils.notifications import show_error
from napari.viewer import Viewer
-from qtpy.QtCore import Qt
-from qtpy.QtWidgets import (
- QGroupBox,
- QPushButton,
- QTabWidget,
- QVBoxLayout,
- QWidget,
-)
+from pytransform3d.rotations import active_matrix_from_angle
+from qtpy.QtWidgets import QCheckBox, QPushButton, QTabWidget
from skimage.segmentation import find_boundaries
+from skimage.transform import rescale
-from brainglobe_registration.elastix.register import run_registration
from brainglobe_registration.utils.utils import (
adjust_napari_image_layer,
+ calculate_rotated_bounding_box,
+ check_atlas_installed,
+ filter_plane,
find_layer_index,
get_image_layer_names,
open_parameter_file,
@@ -44,13 +51,19 @@
)
-class RegistrationWidget(QWidget):
+class RegistrationWidget(CollapsibleWidgetContainer):
def __init__(self, napari_viewer: Viewer):
super().__init__()
+ self.setContentsMargins(10, 10, 10, 10)
self._viewer = napari_viewer
- self._atlas: BrainGlobeAtlas = None
- self._moving_image = None
+ self._atlas: Optional[BrainGlobeAtlas] = None
+ self._atlas_data_layer: Optional[napari.layers.Image] = None
+ self._atlas_annotations_layer: Optional[napari.layers.Labels] = None
+ self._moving_image: Optional[napari.layers.Image] = None
+ self._moving_image_data_backup: Optional[npt.NDArray] = None
+ self._automatic_deletion_flag = False
+ # Flag to differentiate between manual and automatic atlas deletion
self.transform_params: dict[str, dict] = {
"rigid": {},
@@ -84,22 +97,6 @@ def __init__(self, napari_viewer: Viewer):
else:
self._moving_image = None
- self.setLayout(QVBoxLayout())
- self.layout().addWidget(
- header_widget(
- "brainglobe-
registration", # line break at
- "Registration with Elastix",
- github_repo_name="brainglobe-registration",
- )
- )
-
- self.main_tabs = QTabWidget(parent=self)
- self.main_tabs.setTabPosition(QTabWidget.West)
-
- self.settings_tab = QGroupBox()
- self.settings_tab.setLayout(QVBoxLayout())
- self.parameters_tab = QTabWidget()
-
self.get_atlas_widget = SelectImagesView(
available_atlases=self._available_atlases,
sample_image_names=self._sample_images,
@@ -119,10 +116,18 @@ def __init__(self, napari_viewer: Viewer):
self.adjust_moving_image_widget.adjust_image_signal.connect(
self._on_adjust_moving_image
)
-
self.adjust_moving_image_widget.reset_image_signal.connect(
self._on_adjust_moving_image_reset_button_click
)
+ self.adjust_moving_image_widget.scale_image_signal.connect(
+ self._on_scale_moving_image
+ )
+ self.adjust_moving_image_widget.atlas_rotation_signal.connect(
+ self._on_adjust_atlas_rotation
+ )
+ self.adjust_moving_image_widget.reset_atlas_signal.connect(
+ self._on_atlas_reset
+ )
self.transform_select_view = TransformSelectView()
self.transform_select_view.transform_type_added_signal.connect(
@@ -135,17 +140,33 @@ def __init__(self, napari_viewer: Viewer):
self._on_default_file_selection_change
)
+ # Use decorator to connect to layer deletion event
+ self._connect_events()
+ self.filter_checkbox = QCheckBox("Filter Images")
+ self.filter_checkbox.setChecked(False)
+
self.run_button = QPushButton("Run")
self.run_button.clicked.connect(self._on_run_button_click)
self.run_button.setEnabled(False)
- self.settings_tab.layout().addWidget(self.get_atlas_widget)
- self.settings_tab.layout().addWidget(self.adjust_moving_image_widget)
- self.settings_tab.layout().addWidget(self.transform_select_view)
- self.settings_tab.layout().addWidget(self.run_button)
- self.settings_tab.layout().setAlignment(Qt.AlignTop)
+ self.add_widget(
+ header_widget(
+ "brainglobe-
registration", # line break at
+ "Registration with Elastix",
+ github_repo_name="brainglobe-registration",
+ ),
+ collapsible=False,
+ )
+ self.add_widget(self.get_atlas_widget, widget_title="Select Images")
+ self.add_widget(
+ self.adjust_moving_image_widget, widget_title="Prepare Images"
+ )
+ self.add_widget(
+ self.transform_select_view, widget_title="Select Transformations"
+ )
self.parameter_setting_tabs_lists = []
+ self.parameters_tab = QTabWidget(parent=self)
for transform_type in self.transform_params:
new_tab = RegistrationParameterListView(
@@ -156,67 +177,188 @@ def __init__(self, napari_viewer: Viewer):
self.parameters_tab.addTab(new_tab, transform_type)
self.parameter_setting_tabs_lists.append(new_tab)
- self.main_tabs.addTab(self.settings_tab, "Settings")
- self.main_tabs.addTab(self.parameters_tab, "Parameters")
+ self.add_widget(
+ self.parameters_tab, widget_title="Advanced Settings (optional)"
+ )
+
+ self.add_widget(self.filter_checkbox, collapsible=False)
+
+ self.add_widget(self.run_button, collapsible=False)
- self.layout().addWidget(self.main_tabs)
+ self.layout().itemAt(1).widget().collapse(animate=False)
+
+ check_atlas_installed(self)
+
+ def _connect_events(self):
+ @self._viewer.layers.events.removed.connect
+ def _on_layer_deleted(event: Event):
+ if not self._automatic_deletion_flag:
+ self._handle_layer_deletion(event)
+
+ def _handle_layer_deletion(self, event: Event):
+ deleted_layer = event.value
+
+ # Check if the deleted layer is the moving image
+ if self._moving_image == deleted_layer:
+ self._moving_image = None
+ self._moving_image_data_backup = None
+ self._update_dropdowns()
+
+ # Check if deleted layer is the atlas reference / atlas annotations
+ if (
+ self._atlas_data_layer == deleted_layer
+ or self._atlas_annotations_layer == deleted_layer
+ ):
+ # Reset the atlas selection combobox
+ self.get_atlas_widget.reset_atlas_combobox()
+
+ def _delete_atlas_layers(self):
+ # Delete atlas reference layer if it exists
+ if self._atlas_data_layer in self._viewer.layers:
+ self._viewer.layers.remove(self._atlas_data_layer)
+
+ # Delete atlas annotations layer if it exists
+ if self._atlas_annotations_layer in self._viewer.layers:
+ self._viewer.layers.remove(self._atlas_annotations_layer)
+
+ # Clear atlas attributes
+ self._atlas = None
+ self._atlas_data_layer = None
+ self._atlas_annotations_layer = None
+ self.run_button.setEnabled(False)
+ self._viewer.grid.enabled = False
+
+ def _update_dropdowns(self):
+ # Extract the names of the remaining layers
+ layer_names = get_image_layer_names(self._viewer)
+ # Update the dropdowns in SelectImagesView
+ self.get_atlas_widget.update_sample_image_names(layer_names)
def _on_atlas_dropdown_index_changed(self, index):
# Hacky way of having an empty first dropdown
if index == 0:
- if self._atlas:
- curr_atlas_layer_index = find_layer_index(
- self._viewer, self._atlas.atlas_name
- )
-
- self._viewer.layers.pop(curr_atlas_layer_index)
- self._atlas = None
- self.run_button.setEnabled(False)
- self._viewer.grid.enabled = False
-
+ self._delete_atlas_layers()
return
atlas_name = self._available_atlases[index]
- atlas = BrainGlobeAtlas(atlas_name)
if self._atlas:
- curr_atlas_layer_index = find_layer_index(
- self._viewer, self._atlas.atlas_name
- )
-
- self._viewer.layers.pop(curr_atlas_layer_index)
- else:
- self.run_button.setEnabled(True)
+ # Set flag to True when atlas is changed, not manually deleted
+ # Ensures atlas index does not reset to 0
+ self._automatic_deletion_flag = True
+ try:
+ self._delete_atlas_layers()
+ finally:
+ self._automatic_deletion_flag = False
+
+ self.run_button.setEnabled(True)
+
+ self._atlas = BrainGlobeAtlas(atlas_name)
+ dask_reference = da.from_array(
+ self._atlas.reference,
+ chunks=(
+ 1,
+ self._atlas.reference.shape[1],
+ self._atlas.reference.shape[2],
+ ),
+ )
+ dask_annotations = da.from_array(
+ self._atlas.annotation,
+ chunks=(
+ 1,
+ self._atlas.annotation.shape[1],
+ self._atlas.annotation.shape[2],
+ ),
+ )
- self._viewer.add_image(
- atlas.reference,
+ contrast_max = np.max(
+ dask_reference[dask_reference.shape[0] // 2]
+ ).compute()
+ self._atlas_data_layer = self._viewer.add_image(
+ dask_reference,
name=atlas_name,
colormap="gray",
blending="translucent",
+ contrast_limits=[0, contrast_max],
+ multiscale=False,
+ )
+ self._atlas_annotations_layer = self._viewer.add_labels(
+ dask_annotations,
+ name="Annotations",
+ visible=False,
)
- self._atlas = BrainGlobeAtlas(atlas_name=atlas_name)
self._viewer.grid.enabled = True
def _on_sample_dropdown_index_changed(self, index):
viewer_index = find_layer_index(
self._viewer, self._sample_images[index]
)
+ if viewer_index == -1:
+ self._moving_image = None
+ self._moving_image_data_backup = None
+
+ return
+
self._moving_image = self._viewer.layers[viewer_index]
+ self._moving_image_data_backup = self._moving_image.data.copy()
def _on_adjust_moving_image(self, x: int, y: int, rotate: float):
+ if not self._moving_image:
+ show_error(
+ "No moving image selected."
+ "Please select a moving image before adjusting"
+ )
+ return
adjust_napari_image_layer(self._moving_image, x, y, rotate)
def _on_adjust_moving_image_reset_button_click(self):
+ if not self._moving_image:
+ show_error(
+ "No moving image selected. "
+ "Please select a moving image before adjusting"
+ )
+ return
adjust_napari_image_layer(self._moving_image, 0, 0, 0)
def _on_run_button_click(self):
+ from brainglobe_registration.elastix.register import run_registration
+
+ if self._atlas_data_layer is None:
+ display_info(
+ widget=self,
+ title="Warning",
+ message="Please select an atlas before clicking 'Run'.",
+ )
+ return
+
+ if self._moving_image == self._atlas_data_layer:
+ display_info(
+ widget=self,
+ title="Warning",
+ message="Your moving image cannot be an atlas.",
+ )
+ return
+
current_atlas_slice = self._viewer.dims.current_step[0]
+ if self.filter_checkbox.isChecked():
+ atlas_layer = filter_plane(
+ self._atlas_data_layer.data[
+ current_atlas_slice, :, :
+ ].compute()
+ )
+ moving_image_layer = filter_plane(self._moving_image.data)
+ else:
+ atlas_layer = self._atlas_data_layer.data[
+ current_atlas_slice, :, :
+ ]
+ moving_image_layer = self._moving_image.data
+
result, parameters, registered_annotation_image = run_registration(
- self._atlas.reference[current_atlas_slice, :, :],
- self._moving_image.data,
- self._atlas.annotation[current_atlas_slice, :, :],
+ atlas_layer,
+ moving_image_layer,
+ self._atlas_annotations_layer.data[current_atlas_slice, :, :],
self.transform_selections,
)
@@ -313,3 +455,133 @@ def _on_default_file_selection_change(
def _on_sample_popup_about_to_show(self):
self._sample_images = get_image_layer_names(self._viewer)
self.get_atlas_widget.update_sample_image_names(self._sample_images)
+
+ def _on_scale_moving_image(self, x: float, y: float):
+ """
+ Scale the moving image to have resolution equal to the atlas.
+
+ Parameters
+ ------------
+ x : float
+ Moving image x pixel size (> 0.0).
+ y : float
+ Moving image y pixel size (> 0.0).
+
+ Will show an error if the pixel sizes are less than or equal to 0.
+ Will show an error if the moving image or atlas is not selected.
+ """
+ if x <= 0 or y <= 0:
+ show_error("Pixel sizes must be greater than 0")
+ return
+
+ if not (self._moving_image and self._atlas):
+ show_error(
+ "Sample image or atlas not selected. "
+ "Please select a sample image and atlas before scaling",
+ )
+ return
+
+ if self._moving_image_data_backup is None:
+ self._moving_image_data_backup = self._moving_image.data.copy()
+
+ x_factor = x / self._atlas.resolution[0]
+ y_factor = y / self._atlas.resolution[1]
+
+ self._moving_image.data = rescale(
+ self._moving_image_data_backup,
+ (y_factor, x_factor),
+ mode="constant",
+ preserve_range=True,
+ anti_aliasing=True,
+ )
+
+ def _on_adjust_atlas_rotation(self, pitch: float, yaw: float, roll: float):
+ if not (
+ self._atlas
+ and self._atlas_data_layer
+ and self._atlas_annotations_layer
+ ):
+ show_error(
+ "No atlas selected. Please select an atlas before rotating"
+ )
+ return
+
+ # Create the rotation matrix
+ roll_matrix = active_matrix_from_angle(0, np.deg2rad(roll))
+ pitch_matrix = active_matrix_from_angle(2, np.deg2rad(pitch))
+ yaw_matrix = active_matrix_from_angle(1, np.deg2rad(yaw))
+
+ rot_matrix = roll_matrix @ pitch_matrix @ yaw_matrix
+
+ full_matrix = np.eye(4)
+ full_matrix[:-1, :-1] = rot_matrix
+
+ # Translate the origin to the center of the image
+ origin = np.asarray(self._atlas.reference.shape) // 2
+ translate_matrix = np.eye(4)
+ translate_matrix[:-1, -1] = origin
+
+ bounding_box = calculate_rotated_bounding_box(
+ self._atlas.reference.shape, full_matrix
+ )
+ new_translation = np.asarray(bounding_box) // 2
+ post_rotate_translation = np.eye(4)
+ post_rotate_translation[:3, -1] = -new_translation
+
+ # Combine the matrices. The order of operations is:
+ # 1. Translate the origin to the center of the image
+ # 2. Rotate the image
+ # 3. Translate the origin back to the top left corner
+ transform_matrix = (
+ translate_matrix @ full_matrix @ post_rotate_translation
+ )
+
+ self._atlas_data_layer.data = dask_affine_transform(
+ self._atlas.reference,
+ transform_matrix,
+ order=2,
+ output_shape=bounding_box,
+ output_chunks=(2, bounding_box[1], bounding_box[2]),
+ )
+
+ self._atlas_annotations_layer.data = dask_affine_transform(
+ self._atlas.annotation,
+ transform_matrix,
+ order=0,
+ output_shape=bounding_box,
+ output_chunks=(2, bounding_box[1], bounding_box[2]),
+ )
+
+ # Resets the viewer grid to update the grid to the new atlas
+ self._viewer.reset_view()
+
+ worker = self.compute_atlas_rotation(self._atlas_data_layer.data)
+ worker.returned.connect(self.set_atlas_layer_data)
+ worker.start()
+
+ @thread_worker
+ def compute_atlas_rotation(self, dask_array: da.Array):
+ self.adjust_moving_image_widget.reset_atlas_button.setEnabled(False)
+ self.adjust_moving_image_widget.adjust_atlas_rotation.setEnabled(False)
+
+ computed_array = dask_array.compute()
+
+ self.adjust_moving_image_widget.reset_atlas_button.setEnabled(True)
+ self.adjust_moving_image_widget.adjust_atlas_rotation.setEnabled(True)
+
+ return computed_array
+
+ def set_atlas_layer_data(self, new_data):
+ self._atlas_data_layer.data = new_data
+
+ def _on_atlas_reset(self):
+ if not self._atlas:
+ show_error(
+ "No atlas selected. Please select an atlas before resetting"
+ )
+ return
+
+ self._atlas_data_layer.data = self._atlas.reference
+ self._atlas_annotations_layer.data = self._atlas.annotation
+ self._viewer.grid.enabled = False
+ self._viewer.grid.enabled = True
diff --git a/brainglobe_registration/utils/utils.py b/brainglobe_registration/utils/utils.py
index 48d0933..858ebab 100644
--- a/brainglobe_registration/utils/utils.py
+++ b/brainglobe_registration/utils/utils.py
@@ -1,9 +1,15 @@
from pathlib import Path
-from typing import List
+from typing import Dict, List, Tuple
import napari
import numpy as np
+import numpy.typing as npt
+from brainglobe_atlasapi.list_atlases import get_downloaded_atlases
+from brainglobe_utils.qtpy.dialog import display_info
from pytransform3d.rotations import active_matrix_from_angle
+from qtpy.QtWidgets import QWidget
+from scipy.ndimage import gaussian_filter
+from skimage import morphology
def adjust_napari_image_layer(
@@ -20,7 +26,7 @@ def adjust_napari_image_layer(
https://forum.image.sc/t/napari-3d-rotation-center-change-and-scaling/66347/5
Parameters
- ----------
+ ------------
image_layer : napari.layers.Layer
The napari image layer to be adjusted.
x : int
@@ -31,7 +37,7 @@ def adjust_napari_image_layer(
The angle of rotation in degrees.
Returns
- -------
+ --------
None
"""
image_layer.translate = (y, x)
@@ -46,7 +52,7 @@ def adjust_napari_image_layer(
image_layer.affine = transform_matrix
-def open_parameter_file(file_path: Path) -> dict:
+def open_parameter_file(file_path: Path) -> Dict:
"""
Opens the parameter file and returns the parameter dictionary.
@@ -56,13 +62,13 @@ def open_parameter_file(file_path: Path) -> dict:
The values are stripped of any trailing ")" and leading or trailing quotes.
Parameters
- ----------
+ ------------
file_path : Path
The path to the parameter file.
Returns
- -------
- dict
+ --------
+ Dict
A dictionary containing the parameters from the file.
"""
with open(file_path, "r") as f:
@@ -94,5 +100,157 @@ def find_layer_index(viewer: napari.Viewer, layer_name: str) -> int:
def get_image_layer_names(viewer: napari.Viewer) -> List[str]:
"""
Returns a list of the names of the napari image layers in the viewer.
+
+ Parameters
+ ------------
+ viewer : napari.Viewer
+ The napari viewer containing the image layers.
+
+ Returns
+ --------
+ List[str]
+ A list of the names of the image layers in the viewer.
"""
return [layer.name for layer in viewer.layers]
+
+
+def calculate_rotated_bounding_box(
+ image_shape: Tuple[int, int, int], rotation_matrix: npt.NDArray
+) -> Tuple[int, int, int]:
+ """
+ Calculates the bounding box of the rotated image.
+
+ This function calculates the bounding box of the rotated image given the
+ image shape and rotation matrix. The bounding box is calculated by
+ transforming the corners of the image and finding the minimum and maximum
+ values of the transformed corners.
+
+ Parameters
+ ------------
+ image_shape : Tuple[int, int, int]
+ The shape of the image.
+ rotation_matrix : npt.NDArray
+ The rotation matrix.
+
+ Returns
+ --------
+ Tuple[int, int, int]
+ The bounding box of the rotated image.
+ """
+ corners = np.array(
+ [
+ [0, 0, 0, 1],
+ [image_shape[0], 0, 0, 1],
+ [0, image_shape[1], 0, 1],
+ [0, 0, image_shape[2], 1],
+ [image_shape[0], image_shape[1], 0, 1],
+ [image_shape[0], 0, image_shape[2], 1],
+ [0, image_shape[1], image_shape[2], 1],
+ [image_shape[0], image_shape[1], image_shape[2], 1],
+ ]
+ )
+
+ transformed_corners = rotation_matrix @ corners.T
+ min_corner = np.min(transformed_corners, axis=1)
+ max_corner = np.max(transformed_corners, axis=1)
+
+ return (
+ int(np.round(max_corner[0] - min_corner[0])),
+ int(np.round(max_corner[1] - min_corner[1])),
+ int(np.round(max_corner[2] - min_corner[2])),
+ )
+
+
+def check_atlas_installed(parent_widget: QWidget):
+ """
+ Function checks if user has any atlases installed. If not, message box
+ appears in napari, directing user to download atlases via attached links.
+ """
+ available_atlases = get_downloaded_atlases()
+ if len(available_atlases) == 0:
+ display_info(
+ widget=parent_widget,
+ title="Information",
+ message="No atlases available. Please download atlas(es) "
+ "using "
+ "brainglobe-atlasapi or brainrender-napari",
+ )
+
+
+def filter_plane(img_plane):
+ """
+ Apply a set of filter to the plane (typically to avoid overfitting details
+ in the image during registration)
+ The filter is composed of a despeckle filter using opening and a pseudo
+ flatfield filter
+
+ Originally from: [https://github.com/brainglobe/brainreg/blob/main
+ /brainreg/core/utils/preprocess.py]
+
+ Parameters
+ ----------
+ img_plane : np.array
+ A 2D array to filter
+
+ Returns
+ -------
+ np.array
+ Filtered image
+ """
+
+ img_plane = despeckle_by_opening(img_plane)
+ img_plane = pseudo_flatfield(img_plane)
+ return img_plane
+
+
+def despeckle_by_opening(img_plane, radius=2):
+ """
+ Despeckle the image plane using a grayscale opening operation
+
+ Originally from: [https://github.com/brainglobe/brainreg/blob/main
+ /brainreg/core/utils/preprocess.py]
+
+ Parameters
+ ----------
+ img_plane : np.array
+ The image to filter
+
+ radius: int
+ The radius of the opening kernel
+
+ Returns
+ -------
+ np.array
+ The despeckled image
+ """
+ kernel = morphology.disk(radius)
+ morphology.opening(img_plane, out=img_plane, footprint=kernel)
+ return img_plane
+
+
+def pseudo_flatfield(img_plane, sigma=5):
+ """
+ Pseudo flat field filter implementation using a de-trending by a
+ heavily gaussian filtered copy of the image.
+
+ Originally from: [https://github.com/brainglobe/brainreg/blob/main
+ /brainreg/core/utils/preprocess.py]
+
+ Parameters
+ ----------
+ img_plane : np.array
+ The image to filter
+
+ sigma : int
+ The sigma of the gaussian filter applied to the
+ image used for de-trending
+
+ Returns
+ -------
+ np.array
+ The pseudo flat field filtered image
+ """
+ filtered_img = gaussian_filter(img_plane, sigma)
+ return img_plane / (filtered_img + 1)
diff --git a/brainglobe_registration/widgets/adjust_moving_image_view.py b/brainglobe_registration/widgets/adjust_moving_image_view.py
index 4740e25..111951d 100644
--- a/brainglobe_registration/widgets/adjust_moving_image_view.py
+++ b/brainglobe_registration/widgets/adjust_moving_image_view.py
@@ -1,4 +1,4 @@
-from qtpy.QtCore import Signal
+from qtpy.QtCore import Qt, Signal
from qtpy.QtWidgets import (
QDoubleSpinBox,
QFormLayout,
@@ -14,8 +14,8 @@ class AdjustMovingImageView(QWidget):
A QWidget subclass that provides controls for adjusting the moving image.
This widget provides controls for adjusting the x and y offsets and
- rotation of the moving image. It emits signals when the image is adjusted
- or reset.
+ rotation of the moving image. It emits signals when the image is adjusted,
+ scaled, or reset.
Attributes
----------
@@ -33,9 +33,18 @@ class AdjustMovingImageView(QWidget):
_on_reset_image_button_click():
Resets the x and y offsets and rotation to 0 and emits the
reset_image_signal.
+ _on_scale_image_button_click():
+ Emits the scale_image_signal with the entered pixel sizes.
+ _on_adjust_atlas_rotation():
+ Emits the atlas_rotation_signal with the entered pitch, yaw, and roll.
+ _on_atlas_reset():
+ Resets the pitch, yaw, and roll to 0 and emits the atlas_reset_signal.
"""
adjust_image_signal = Signal(int, int, float)
+ scale_image_signal = Signal(float, float)
+ atlas_rotation_signal = Signal(float, float, float)
+ reset_atlas_signal = Signal()
reset_image_signal = Signal()
def __init__(self, parent=None):
@@ -50,19 +59,53 @@ def __init__(self, parent=None):
super().__init__(parent=parent)
self.setLayout(QFormLayout())
+ self.layout().setLabelAlignment(Qt.AlignLeft)
offset_range = 2000
rotation_range = 360
- self.adjust_moving_image_x = QSpinBox()
+ self.adjust_moving_image_pixel_size_x = QDoubleSpinBox(parent=self)
+ self.adjust_moving_image_pixel_size_x.setDecimals(3)
+ self.adjust_moving_image_pixel_size_x.setRange(0.001, 100.00)
+ self.adjust_moving_image_pixel_size_y = QDoubleSpinBox(parent=self)
+ self.adjust_moving_image_pixel_size_y.setDecimals(3)
+ self.adjust_moving_image_pixel_size_y.setRange(0.001, 100.00)
+ self.scale_moving_image_button = QPushButton()
+ self.scale_moving_image_button.setText("Scale Image")
+ self.scale_moving_image_button.clicked.connect(
+ self._on_scale_image_button_click
+ )
+
+ self.adjust_atlas_pitch = QDoubleSpinBox(parent=self)
+ self.adjust_atlas_pitch.setSingleStep(0.1)
+ self.adjust_atlas_pitch.setRange(-rotation_range, rotation_range)
+
+ self.adjust_atlas_yaw = QDoubleSpinBox(parent=self)
+ self.adjust_atlas_yaw.setSingleStep(0.1)
+ self.adjust_atlas_yaw.setRange(-rotation_range, rotation_range)
+
+ self.adjust_atlas_roll = QDoubleSpinBox(parent=self)
+ self.adjust_atlas_roll.setSingleStep(0.1)
+ self.adjust_atlas_roll.setRange(-rotation_range, rotation_range)
+
+ self.adjust_atlas_rotation = QPushButton()
+ self.adjust_atlas_rotation.setText("Rotate Atlas")
+ self.adjust_atlas_rotation.clicked.connect(
+ self._on_adjust_atlas_rotation
+ )
+ self.reset_atlas_button = QPushButton()
+ self.reset_atlas_button.setText("Reset Atlas")
+ self.reset_atlas_button.clicked.connect(self._on_atlas_reset)
+
+ self.adjust_moving_image_x = QSpinBox(parent=self)
self.adjust_moving_image_x.setRange(-offset_range, offset_range)
self.adjust_moving_image_x.valueChanged.connect(self._on_adjust_image)
- self.adjust_moving_image_y = QSpinBox()
+ self.adjust_moving_image_y = QSpinBox(parent=self)
self.adjust_moving_image_y.setRange(-offset_range, offset_range)
self.adjust_moving_image_y.valueChanged.connect(self._on_adjust_image)
- self.adjust_moving_image_rotate = QDoubleSpinBox()
+ self.adjust_moving_image_rotate = QDoubleSpinBox(parent=self)
self.adjust_moving_image_rotate.setRange(
-rotation_range, rotation_range
)
@@ -71,13 +114,31 @@ def __init__(self, parent=None):
self._on_adjust_image
)
- self.adjust_moving_image_reset_button = QPushButton(parent=self)
+ self.adjust_moving_image_reset_button = QPushButton()
self.adjust_moving_image_reset_button.setText("Reset Image")
self.adjust_moving_image_reset_button.clicked.connect(
self._on_reset_image_button_click
)
- self.layout().addRow(QLabel("Adjust the moving image: "))
+ self.layout().addRow(QLabel("Adjust the moving image scale:"))
+ self.layout().addRow(
+ "Sample image X pixel size (\u03BCm / pixel):",
+ self.adjust_moving_image_pixel_size_x,
+ )
+ self.layout().addRow(
+ "Sample image Y pixel size (\u03BCm / pixel):",
+ self.adjust_moving_image_pixel_size_y,
+ )
+ self.layout().addRow(self.scale_moving_image_button)
+
+ self.layout().addRow(QLabel("Adjust the atlas pitch and yaw: "))
+ self.layout().addRow("Pitch:", self.adjust_atlas_pitch)
+ self.layout().addRow("Yaw:", self.adjust_atlas_yaw)
+ self.layout().addRow("Roll:", self.adjust_atlas_roll)
+ self.layout().addRow(self.adjust_atlas_rotation)
+ self.layout().addRow(self.reset_atlas_button)
+
+ self.layout().addRow(QLabel("Adjust the moving image position: "))
self.layout().addRow("X offset:", self.adjust_moving_image_x)
self.layout().addRow("Y offset:", self.adjust_moving_image_y)
self.layout().addRow(
@@ -106,3 +167,32 @@ def _on_reset_image_button_click(self):
self.adjust_moving_image_rotate.setValue(0)
self.reset_image_signal.emit()
+
+ def _on_scale_image_button_click(self):
+ """
+ Emit the scale_image_signal with the entered pixel sizes.
+ """
+ self.scale_image_signal.emit(
+ self.adjust_moving_image_pixel_size_x.value(),
+ self.adjust_moving_image_pixel_size_y.value(),
+ )
+
+ def _on_adjust_atlas_rotation(self):
+ """
+ Emit the atlas_rotation_signal with the entered pitch and yaw.
+ """
+ self.atlas_rotation_signal.emit(
+ self.adjust_atlas_pitch.value(),
+ self.adjust_atlas_yaw.value(),
+ self.adjust_atlas_roll.value(),
+ )
+
+ def _on_atlas_reset(self):
+ """
+ Reset the pitch, yaw, and roll to 0 and emit the atlas_reset_signal.
+ """
+ self.adjust_atlas_yaw.setValue(0)
+ self.adjust_atlas_pitch.setValue(0)
+ self.adjust_atlas_roll.setValue(0)
+
+ self.reset_atlas_signal.emit()
diff --git a/brainglobe_registration/widgets/select_images_view.py b/brainglobe_registration/widgets/select_images_view.py
index 17bc737..ad9046c 100644
--- a/brainglobe_registration/widgets/select_images_view.py
+++ b/brainglobe_registration/widgets/select_images_view.py
@@ -98,7 +98,8 @@ def __init__(
def update_sample_image_names(self, sample_image_names: List[str]):
self.available_sample_dropdown.clear()
- self.available_sample_dropdown.addItems(sample_image_names)
+ if len(sample_image_names) != 0:
+ self.available_sample_dropdown.addItems(sample_image_names)
def _on_atlas_index_change(self):
"""
@@ -122,5 +123,9 @@ def _on_moving_image_index_change(self):
self.available_sample_dropdown.currentIndex()
)
+ def reset_atlas_combobox(self):
+ if self.available_atlas_dropdown is not None:
+ self.available_atlas_dropdown.setCurrentIndex(0)
+
def _on_sample_popup_about_to_show(self):
self.sample_image_popup_about_to_show.emit()
diff --git a/brainglobe_registration/widgets/transform_select_view.py b/brainglobe_registration/widgets/transform_select_view.py
index eba8c32..d141816 100644
--- a/brainglobe_registration/widgets/transform_select_view.py
+++ b/brainglobe_registration/widgets/transform_select_view.py
@@ -169,7 +169,7 @@ def _on_transform_type_change(self, index):
self.file_selections[index].setCurrentIndex(0)
if index >= len(self.transform_type_selections) - 1:
- curr_length = self.rowCount()
+ current_length = self.rowCount()
self.setRowCount(self.rowCount() + 1)
self.transform_type_selections.append(QComboBox())
@@ -180,7 +180,7 @@ def _on_transform_type_change(self, index):
self.transform_type_signaller.map
)
self.transform_type_signaller.setMapping(
- self.transform_type_selections[-1], curr_length
+ self.transform_type_selections[-1], current_length
)
self.file_selections.append(QComboBox())
@@ -191,14 +191,16 @@ def _on_transform_type_change(self, index):
self.file_signaller.map
)
self.file_signaller.setMapping(
- self.file_selections[-1], curr_length
+ self.file_selections[-1], current_length
)
self.setCellWidget(
- curr_length, 0, self.transform_type_selections[curr_length]
+ current_length,
+ 0,
+ self.transform_type_selections[current_length],
)
self.setCellWidget(
- curr_length, 1, self.file_selections[curr_length]
+ current_length, 1, self.file_selections[current_length]
)
else:
diff --git a/pyproject.toml b/pyproject.toml
index 9317c22..0d14999 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,6 +26,8 @@ dependencies = [
"napari>=0.4.18",
"brainglobe-atlasapi",
"brainglobe-utils>=0.4.3",
+ "dask",
+ "dask-image",
"itk-elastix",
"lxml_html_clean",
"numpy",
diff --git a/tests/test_adjust_moving_image_view.py b/tests/test_adjust_moving_image_view.py
index e56107e..8eae56c 100644
--- a/tests/test_adjust_moving_image_view.py
+++ b/tests/test_adjust_moving_image_view.py
@@ -14,6 +14,12 @@ def adjust_moving_image_view() -> AdjustMovingImageView:
return adjust_moving_image_view
+def test_init(qtbot, adjust_moving_image_view):
+ qtbot.addWidget(adjust_moving_image_view)
+
+ assert adjust_moving_image_view.layout().rowCount() == 15
+
+
@pytest.mark.parametrize(
"x_value, expected",
[
@@ -94,3 +100,62 @@ def test_reset_image_button_click(qtbot, adjust_moving_image_view):
assert adjust_moving_image_view.adjust_moving_image_x.value() == 0
assert adjust_moving_image_view.adjust_moving_image_y.value() == 0
assert adjust_moving_image_view.adjust_moving_image_rotate.value() == 0
+
+
+@pytest.mark.parametrize(
+ "x_scale, y_scale",
+ [(2.5, 2.5), (10, 20), (10.2212, 10.2289)],
+)
+def test_scale_image_button_click(
+ qtbot, adjust_moving_image_view, x_scale, y_scale
+):
+ qtbot.addWidget(adjust_moving_image_view)
+
+ with qtbot.waitSignal(
+ adjust_moving_image_view.scale_image_signal, timeout=1000
+ ) as blocker:
+ adjust_moving_image_view.adjust_moving_image_pixel_size_x.setValue(
+ x_scale
+ )
+ adjust_moving_image_view.adjust_moving_image_pixel_size_y.setValue(
+ y_scale
+ )
+ adjust_moving_image_view.scale_moving_image_button.click()
+
+ assert blocker.args == [round(x_scale, 3), round(y_scale, 3)]
+
+
+def test_atlas_rotation_changed(
+ qtbot,
+ adjust_moving_image_view,
+):
+ qtbot.addWidget(adjust_moving_image_view)
+
+ with qtbot.waitSignal(
+ adjust_moving_image_view.atlas_rotation_signal, timeout=1000
+ ) as blocker:
+ adjust_moving_image_view.adjust_atlas_pitch.setValue(10)
+ adjust_moving_image_view.adjust_atlas_yaw.setValue(20)
+ adjust_moving_image_view.adjust_atlas_roll.setValue(30)
+
+ adjust_moving_image_view.adjust_atlas_rotation.click()
+
+ assert blocker.args == [10, 20, 30]
+
+
+def test_atlas_reset_button_click(
+ qtbot,
+ adjust_moving_image_view,
+):
+ qtbot.addWidget(adjust_moving_image_view)
+
+ with qtbot.waitSignal(
+ adjust_moving_image_view.reset_atlas_signal, timeout=1000
+ ):
+ adjust_moving_image_view.reset_atlas_button.click()
+
+ assert (
+ adjust_moving_image_view.adjust_atlas_pitch.value() == 0
+ and adjust_moving_image_view.adjust_atlas_yaw.value() == 0
+ and adjust_moving_image_view.adjust_atlas_roll.value() == 0
+ )
diff --git a/tests/test_parameter_list_view.py b/tests/test_parameter_list_view.py
index 00a701f..cabbfce 100644
--- a/tests/test_parameter_list_view.py
+++ b/tests/test_parameter_list_view.py
@@ -56,7 +56,7 @@ def test_parameter_list_view_cell_change(parameter_list_view, qtbot):
def test_parameter_list_view_cell_change_last_row(parameter_list_view, qtbot):
qtbot.addWidget(parameter_list_view)
- curr_row_count = parameter_list_view.rowCount()
+ current_row_count = parameter_list_view.rowCount()
last_row_index = len(param_dict)
parameter_list_view.setItem(
@@ -65,7 +65,7 @@ def test_parameter_list_view_cell_change_last_row(parameter_list_view, qtbot):
parameter_list_view.setItem(last_row_index, 1, QTableWidgetItem("true"))
assert parameter_list_view.param_dict["TestParameter"] == ["true"]
- assert parameter_list_view.rowCount() == curr_row_count + 1
+ assert parameter_list_view.rowCount() == current_row_count + 1
def test_parameter_list_view_cell_change_last_row_no_param(
@@ -73,12 +73,12 @@ def test_parameter_list_view_cell_change_last_row_no_param(
):
qtbot.addWidget(parameter_list_view)
- curr_row_count = parameter_list_view.rowCount()
+ current_row_count = parameter_list_view.rowCount()
last_row_index = len(param_dict)
parameter_list_view.setItem(last_row_index, 1, QTableWidgetItem("true"))
- assert parameter_list_view.rowCount() == curr_row_count
+ assert parameter_list_view.rowCount() == current_row_count
def test_parameter_list_view_cell_change_last_row_no_value(
@@ -86,11 +86,11 @@ def test_parameter_list_view_cell_change_last_row_no_value(
):
qtbot.addWidget(parameter_list_view)
- curr_row_count = parameter_list_view.rowCount()
+ current_row_count = parameter_list_view.rowCount()
last_row_index = len(param_dict)
parameter_list_view.setItem(
last_row_index, 0, QTableWidgetItem("TestParameter")
)
- assert parameter_list_view.rowCount() == curr_row_count
+ assert parameter_list_view.rowCount() == current_row_count
diff --git a/tests/test_registration_widget.py b/tests/test_registration_widget.py
index aa91594..828c4aa 100644
--- a/tests/test_registration_widget.py
+++ b/tests/test_registration_widget.py
@@ -12,6 +12,29 @@ def registration_widget(make_napari_viewer_with_images):
return widget
+@pytest.fixture()
+def registration_widget_with_example_atlas(make_napari_viewer_with_images):
+ """
+ Create an initialised RegistrationWidget with the "example_mouse_100um"
+ loaded.
+
+ Parameters
+ ------------
+ make_napari_viewer_with_images for testing
+ Fixture that creates a napari viewer
+ """
+ viewer = make_napari_viewer_with_images
+
+ widget = RegistrationWidget(viewer)
+
+ # Based on the downloaded atlases by the fixture in conftest.py,
+ # the example atlas will be in the third position.
+ example_atlas_index = 2
+ widget._on_atlas_dropdown_index_changed(example_atlas_index)
+
+ return widget
+
+
def test_registration_widget(make_napari_viewer_with_images):
widget = RegistrationWidget(make_napari_viewer_with_images)
@@ -65,3 +88,129 @@ def test_sample_dropdown_index_changed_with_valid_index(
registration_widget._moving_image.name
== registration_widget._sample_images[1]
)
+
+
+def test_scale_moving_image_no_atlas(
+ make_napari_viewer_with_images, registration_widget, mocker
+):
+ mocked_show_error = mocker.patch(
+ "brainglobe_registration.registration_widget.show_error"
+ )
+ registration_widget._atlas = None
+ registration_widget.adjust_moving_image_widget.scale_image_signal.emit(
+ 10, 10
+ )
+ mocked_show_error.assert_called_once_with(
+ "Sample image or atlas not selected. "
+ "Please select a sample image and atlas before scaling"
+ )
+
+
+def test_scale_moving_image_no_sample_image(
+ make_napari_viewer_with_images, registration_widget, mocker
+):
+ mocked_show_error = mocker.patch(
+ "brainglobe_registration.registration_widget.show_error"
+ )
+ registration_widget._moving_image = None
+ registration_widget.adjust_moving_image_widget.scale_image_signal.emit(
+ 10, 10
+ )
+ mocked_show_error.assert_called_once_with(
+ "Sample image or atlas not selected. "
+ "Please select a sample image and atlas before scaling"
+ )
+
+
+@pytest.mark.parametrize(
+ "x_scale_factor, y_scale_factor",
+ [
+ (0.5, 0.5),
+ (1.0, 1.0),
+ (2.0, 2.0),
+ (0.5, 1.0),
+ (1.0, 0.5),
+ ],
+)
+def test_scale_moving_image(
+ make_napari_viewer_with_images,
+ registration_widget,
+ mocker,
+ x_scale_factor,
+ y_scale_factor,
+):
+ mock_atlas = mocker.patch(
+ "brainglobe_registration.registration_widget.BrainGlobeAtlas"
+ )
+ mock_atlas.resolution = [20, 20, 20]
+ registration_widget._atlas = mock_atlas
+
+ current_size = registration_widget._moving_image.data.shape
+ registration_widget.adjust_moving_image_widget.scale_image_signal.emit(
+ mock_atlas.resolution[1] * x_scale_factor,
+ mock_atlas.resolution[2] * y_scale_factor,
+ )
+
+ assert registration_widget._moving_image.data.shape == (
+ current_size[0] * y_scale_factor,
+ current_size[1] * x_scale_factor,
+ )
+
+
+@pytest.mark.parametrize(
+ "pitch, yaw, roll, expected_shape",
+ [
+ (0, 0, 0, (132, 80, 114)),
+ (45, 0, 0, (150, 150, 114)),
+ (0, 45, 0, (174, 80, 174)),
+ (0, 0, 45, (132, 137, 137)),
+ ],
+)
+def test_on_adjust_atlas_rotation(
+ registration_widget_with_example_atlas,
+ pitch,
+ yaw,
+ roll,
+ expected_shape,
+):
+ reg_widget = registration_widget_with_example_atlas
+ atlas_shape = reg_widget._atlas.reference.shape
+
+ reg_widget._on_adjust_atlas_rotation(pitch, yaw, roll)
+
+ assert reg_widget._atlas_data_layer.data.shape == expected_shape
+ assert reg_widget._atlas_annotations_layer.data.shape == expected_shape
+ assert reg_widget._atlas.reference.shape == atlas_shape
+
+
+def test_on_adjust_atlas_rotation_no_atlas(registration_widget, mocker):
+ mocked_show_error = mocker.patch(
+ "brainglobe_registration.registration_widget.show_error"
+ )
+ registration_widget._on_adjust_atlas_rotation(10, 10, 10)
+ mocked_show_error.assert_called_once_with(
+ "No atlas selected. Please select an atlas before rotating"
+ )
+
+
+def test_on_atlas_reset(registration_widget_with_example_atlas):
+ reg_widget = registration_widget_with_example_atlas
+ atlas_shape = reg_widget._atlas.reference.shape
+ reg_widget._on_adjust_atlas_rotation(10, 10, 10)
+
+ reg_widget._on_atlas_reset()
+
+ assert reg_widget._atlas_data_layer.data.shape == atlas_shape
+ assert reg_widget._atlas.reference.shape == atlas_shape
+ assert reg_widget._atlas_annotations_layer.data.shape == atlas_shape
+
+
+def test_on_atlas_reset_no_atlas(registration_widget, mocker):
+ mocked_show_error = mocker.patch(
+ "brainglobe_registration.registration_widget.show_error"
+ )
+
+ registration_widget._on_atlas_reset()
+ mocked_show_error.assert_called_once_with(
+ "No atlas selected. Please select an atlas before resetting"
+ )
diff --git a/tests/test_utils.py b/tests/test_utils.py
index d857e09..89068e7 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -7,6 +7,7 @@
from brainglobe_registration.utils.utils import (
adjust_napari_image_layer,
+ calculate_rotated_bounding_box,
find_layer_index,
get_image_layer_names,
open_parameter_file,
@@ -93,3 +94,32 @@ def test_get_image_layer_names(make_napari_viewer_with_images):
assert len(layer_names) == 2
assert layer_names == ["moving_image", "atlas_image"]
+
+
+@pytest.mark.parametrize(
+ "basis, rotation, expected_bounds",
+ [
+ (0, 0, (50, 100, 200)),
+ (0, 90, (50, 200, 100)),
+ (0, 180, (50, 100, 200)),
+ (0, 45, (50, 212, 212)),
+ (1, 0, (50, 100, 200)),
+ (1, 90, (200, 100, 50)),
+ (1, 180, (50, 100, 200)),
+ (1, 45, (177, 100, 177)),
+ (2, 0, (50, 100, 200)),
+ (2, 90, (100, 50, 200)),
+ (2, 180, (50, 100, 200)),
+ (2, 45, (106, 106, 200)),
+ ],
+)
+def test_calculate_rotated_bounding_box(basis, rotation, expected_bounds):
+ image_shape = (50, 100, 200)
+ rotation_matrix = np.eye(4)
+ rotation_matrix[:-1, :-1] = active_matrix_from_angle(
+ basis, np.deg2rad(rotation)
+ )
+
+ result_shape = calculate_rotated_bounding_box(image_shape, rotation_matrix)
+
+ assert result_shape == expected_bounds