diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 08893b8..9546c84 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -49,6 +49,23 @@ jobs: with: qt: true + # cache atlases needed by the tests + - name: Cache Atlases + id: atlas-cache + uses: actions/cache@v3 + with: + path: | # ensure we don't cache any interrupted atlas download and extraction! + ~/.brainglobe/* + !~/.brainglobe/atlas.tar.gz + key: ${{ runner.os }}-cached-atlases + enableCrossOsArchive: false # ~ and $HOME evaluate to different places across OSs! + + - if: ${{ steps.atlas-cache.outputs.cache-hit == 'true' }} + name: List files in brainglobe data folder # good to be able to sanity check that user data is as expected + run: | + ls -af ~/.brainglobe/ + + # Run tests - uses: neuroinformatics-unit/actions/test@v2 with: diff --git a/.gitignore b/.gitignore index 73d56d3..08e97e3 100644 --- a/.gitignore +++ b/.gitignore @@ -82,3 +82,6 @@ venv/ # written by setuptools_scm **/_version.py + +# Elastix related files +/Logs/ diff --git a/brainglobe_registration/elastix/register.py b/brainglobe_registration/elastix/register.py index d04d918..4a216f9 100644 --- a/brainglobe_registration/elastix/register.py +++ b/brainglobe_registration/elastix/register.py @@ -1,7 +1,8 @@ -from typing import List +from typing import List, Tuple import itk import numpy as np +import numpy.typing as npt from brainglobe_atlasapi import BrainGlobeAtlas @@ -28,28 +29,30 @@ def run_registration( atlas_image, moving_image, annotation_image, - parameter_lists: List[tuple[str, dict]], -) -> tuple[np.ndarray, itk.ParameterObject, np.ndarray]: + parameter_lists: List[Tuple[str, dict]], +) -> Tuple[npt.NDArray, itk.ParameterObject, npt.NDArray]: """ Run the registration process on the given images. Parameters ---------- - atlas_image : np.ndarray + atlas_image : npt.NDArray The atlas image. - moving_image : np.ndarray + moving_image : npt.NDArray The moving image. - annotation_image : np.ndarray + annotation_image : npt.NDArray The annotation image. parameter_lists : List[tuple[str, dict]], optional The list of parameter lists, by default None Returns ------- - np.ndarray + npt.NDArray The result image. itk.ParameterObject The result transform parameters. + npt.NDArray + The transformed annotation image. """ # convert to ITK, view only atlas_image = itk.GetImageViewFromArray(atlas_image).astype(itk.F) diff --git a/brainglobe_registration/registration_widget.py b/brainglobe_registration/registration_widget.py index d389d0c..4d0f367 100644 --- a/brainglobe_registration/registration_widget.py +++ b/brainglobe_registration/registration_widget.py @@ -9,25 +9,32 @@ """ from pathlib import Path +from typing import Optional +import dask.array as da +import napari.layers import numpy as np +import numpy.typing as npt from brainglobe_atlasapi import BrainGlobeAtlas from brainglobe_atlasapi.list_atlases import get_downloaded_atlases +from brainglobe_utils.qtpy.collapsible_widget import CollapsibleWidgetContainer +from brainglobe_utils.qtpy.dialog import display_info from brainglobe_utils.qtpy.logo import header_widget +from dask_image.ndinterp import affine_transform as dask_affine_transform +from napari.qt.threading import thread_worker +from napari.utils.events import Event +from napari.utils.notifications import show_error from napari.viewer import Viewer -from qtpy.QtCore import Qt -from qtpy.QtWidgets import ( - QGroupBox, - QPushButton, - QTabWidget, - QVBoxLayout, - QWidget, -) +from pytransform3d.rotations import active_matrix_from_angle +from qtpy.QtWidgets import QCheckBox, QPushButton, QTabWidget from skimage.segmentation import find_boundaries +from skimage.transform import rescale -from brainglobe_registration.elastix.register import run_registration from brainglobe_registration.utils.utils import ( adjust_napari_image_layer, + calculate_rotated_bounding_box, + check_atlas_installed, + filter_plane, find_layer_index, get_image_layer_names, open_parameter_file, @@ -44,13 +51,19 @@ ) -class RegistrationWidget(QWidget): +class RegistrationWidget(CollapsibleWidgetContainer): def __init__(self, napari_viewer: Viewer): super().__init__() + self.setContentsMargins(10, 10, 10, 10) self._viewer = napari_viewer - self._atlas: BrainGlobeAtlas = None - self._moving_image = None + self._atlas: Optional[BrainGlobeAtlas] = None + self._atlas_data_layer: Optional[napari.layers.Image] = None + self._atlas_annotations_layer: Optional[napari.layers.Labels] = None + self._moving_image: Optional[napari.layers.Image] = None + self._moving_image_data_backup: Optional[npt.NDArray] = None + self._automatic_deletion_flag = False + # Flag to differentiate between manual and automatic atlas deletion self.transform_params: dict[str, dict] = { "rigid": {}, @@ -84,22 +97,6 @@ def __init__(self, napari_viewer: Viewer): else: self._moving_image = None - self.setLayout(QVBoxLayout()) - self.layout().addWidget( - header_widget( - "brainglobe-
registration", # line break at
- "Registration with Elastix", - github_repo_name="brainglobe-registration", - ) - ) - - self.main_tabs = QTabWidget(parent=self) - self.main_tabs.setTabPosition(QTabWidget.West) - - self.settings_tab = QGroupBox() - self.settings_tab.setLayout(QVBoxLayout()) - self.parameters_tab = QTabWidget() - self.get_atlas_widget = SelectImagesView( available_atlases=self._available_atlases, sample_image_names=self._sample_images, @@ -119,10 +116,18 @@ def __init__(self, napari_viewer: Viewer): self.adjust_moving_image_widget.adjust_image_signal.connect( self._on_adjust_moving_image ) - self.adjust_moving_image_widget.reset_image_signal.connect( self._on_adjust_moving_image_reset_button_click ) + self.adjust_moving_image_widget.scale_image_signal.connect( + self._on_scale_moving_image + ) + self.adjust_moving_image_widget.atlas_rotation_signal.connect( + self._on_adjust_atlas_rotation + ) + self.adjust_moving_image_widget.reset_atlas_signal.connect( + self._on_atlas_reset + ) self.transform_select_view = TransformSelectView() self.transform_select_view.transform_type_added_signal.connect( @@ -135,17 +140,33 @@ def __init__(self, napari_viewer: Viewer): self._on_default_file_selection_change ) + # Use decorator to connect to layer deletion event + self._connect_events() + self.filter_checkbox = QCheckBox("Filter Images") + self.filter_checkbox.setChecked(False) + self.run_button = QPushButton("Run") self.run_button.clicked.connect(self._on_run_button_click) self.run_button.setEnabled(False) - self.settings_tab.layout().addWidget(self.get_atlas_widget) - self.settings_tab.layout().addWidget(self.adjust_moving_image_widget) - self.settings_tab.layout().addWidget(self.transform_select_view) - self.settings_tab.layout().addWidget(self.run_button) - self.settings_tab.layout().setAlignment(Qt.AlignTop) + self.add_widget( + header_widget( + "brainglobe-
registration", # line break at
+ "Registration with Elastix", + github_repo_name="brainglobe-registration", + ), + collapsible=False, + ) + self.add_widget(self.get_atlas_widget, widget_title="Select Images") + self.add_widget( + self.adjust_moving_image_widget, widget_title="Prepare Images" + ) + self.add_widget( + self.transform_select_view, widget_title="Select Transformations" + ) self.parameter_setting_tabs_lists = [] + self.parameters_tab = QTabWidget(parent=self) for transform_type in self.transform_params: new_tab = RegistrationParameterListView( @@ -156,67 +177,188 @@ def __init__(self, napari_viewer: Viewer): self.parameters_tab.addTab(new_tab, transform_type) self.parameter_setting_tabs_lists.append(new_tab) - self.main_tabs.addTab(self.settings_tab, "Settings") - self.main_tabs.addTab(self.parameters_tab, "Parameters") + self.add_widget( + self.parameters_tab, widget_title="Advanced Settings (optional)" + ) + + self.add_widget(self.filter_checkbox, collapsible=False) + + self.add_widget(self.run_button, collapsible=False) - self.layout().addWidget(self.main_tabs) + self.layout().itemAt(1).widget().collapse(animate=False) + + check_atlas_installed(self) + + def _connect_events(self): + @self._viewer.layers.events.removed.connect + def _on_layer_deleted(event: Event): + if not self._automatic_deletion_flag: + self._handle_layer_deletion(event) + + def _handle_layer_deletion(self, event: Event): + deleted_layer = event.value + + # Check if the deleted layer is the moving image + if self._moving_image == deleted_layer: + self._moving_image = None + self._moving_image_data_backup = None + self._update_dropdowns() + + # Check if deleted layer is the atlas reference / atlas annotations + if ( + self._atlas_data_layer == deleted_layer + or self._atlas_annotations_layer == deleted_layer + ): + # Reset the atlas selection combobox + self.get_atlas_widget.reset_atlas_combobox() + + def _delete_atlas_layers(self): + # Delete atlas reference layer if it exists + if self._atlas_data_layer in self._viewer.layers: + self._viewer.layers.remove(self._atlas_data_layer) + + # Delete atlas annotations layer if it exists + if self._atlas_annotations_layer in self._viewer.layers: + self._viewer.layers.remove(self._atlas_annotations_layer) + + # Clear atlas attributes + self._atlas = None + self._atlas_data_layer = None + self._atlas_annotations_layer = None + self.run_button.setEnabled(False) + self._viewer.grid.enabled = False + + def _update_dropdowns(self): + # Extract the names of the remaining layers + layer_names = get_image_layer_names(self._viewer) + # Update the dropdowns in SelectImagesView + self.get_atlas_widget.update_sample_image_names(layer_names) def _on_atlas_dropdown_index_changed(self, index): # Hacky way of having an empty first dropdown if index == 0: - if self._atlas: - curr_atlas_layer_index = find_layer_index( - self._viewer, self._atlas.atlas_name - ) - - self._viewer.layers.pop(curr_atlas_layer_index) - self._atlas = None - self.run_button.setEnabled(False) - self._viewer.grid.enabled = False - + self._delete_atlas_layers() return atlas_name = self._available_atlases[index] - atlas = BrainGlobeAtlas(atlas_name) if self._atlas: - curr_atlas_layer_index = find_layer_index( - self._viewer, self._atlas.atlas_name - ) - - self._viewer.layers.pop(curr_atlas_layer_index) - else: - self.run_button.setEnabled(True) + # Set flag to True when atlas is changed, not manually deleted + # Ensures atlas index does not reset to 0 + self._automatic_deletion_flag = True + try: + self._delete_atlas_layers() + finally: + self._automatic_deletion_flag = False + + self.run_button.setEnabled(True) + + self._atlas = BrainGlobeAtlas(atlas_name) + dask_reference = da.from_array( + self._atlas.reference, + chunks=( + 1, + self._atlas.reference.shape[1], + self._atlas.reference.shape[2], + ), + ) + dask_annotations = da.from_array( + self._atlas.annotation, + chunks=( + 1, + self._atlas.annotation.shape[1], + self._atlas.annotation.shape[2], + ), + ) - self._viewer.add_image( - atlas.reference, + contrast_max = np.max( + dask_reference[dask_reference.shape[0] // 2] + ).compute() + self._atlas_data_layer = self._viewer.add_image( + dask_reference, name=atlas_name, colormap="gray", blending="translucent", + contrast_limits=[0, contrast_max], + multiscale=False, + ) + self._atlas_annotations_layer = self._viewer.add_labels( + dask_annotations, + name="Annotations", + visible=False, ) - self._atlas = BrainGlobeAtlas(atlas_name=atlas_name) self._viewer.grid.enabled = True def _on_sample_dropdown_index_changed(self, index): viewer_index = find_layer_index( self._viewer, self._sample_images[index] ) + if viewer_index == -1: + self._moving_image = None + self._moving_image_data_backup = None + + return + self._moving_image = self._viewer.layers[viewer_index] + self._moving_image_data_backup = self._moving_image.data.copy() def _on_adjust_moving_image(self, x: int, y: int, rotate: float): + if not self._moving_image: + show_error( + "No moving image selected." + "Please select a moving image before adjusting" + ) + return adjust_napari_image_layer(self._moving_image, x, y, rotate) def _on_adjust_moving_image_reset_button_click(self): + if not self._moving_image: + show_error( + "No moving image selected. " + "Please select a moving image before adjusting" + ) + return adjust_napari_image_layer(self._moving_image, 0, 0, 0) def _on_run_button_click(self): + from brainglobe_registration.elastix.register import run_registration + + if self._atlas_data_layer is None: + display_info( + widget=self, + title="Warning", + message="Please select an atlas before clicking 'Run'.", + ) + return + + if self._moving_image == self._atlas_data_layer: + display_info( + widget=self, + title="Warning", + message="Your moving image cannot be an atlas.", + ) + return + current_atlas_slice = self._viewer.dims.current_step[0] + if self.filter_checkbox.isChecked(): + atlas_layer = filter_plane( + self._atlas_data_layer.data[ + current_atlas_slice, :, : + ].compute() + ) + moving_image_layer = filter_plane(self._moving_image.data) + else: + atlas_layer = self._atlas_data_layer.data[ + current_atlas_slice, :, : + ] + moving_image_layer = self._moving_image.data + result, parameters, registered_annotation_image = run_registration( - self._atlas.reference[current_atlas_slice, :, :], - self._moving_image.data, - self._atlas.annotation[current_atlas_slice, :, :], + atlas_layer, + moving_image_layer, + self._atlas_annotations_layer.data[current_atlas_slice, :, :], self.transform_selections, ) @@ -313,3 +455,133 @@ def _on_default_file_selection_change( def _on_sample_popup_about_to_show(self): self._sample_images = get_image_layer_names(self._viewer) self.get_atlas_widget.update_sample_image_names(self._sample_images) + + def _on_scale_moving_image(self, x: float, y: float): + """ + Scale the moving image to have resolution equal to the atlas. + + Parameters + ------------ + x : float + Moving image x pixel size (> 0.0). + y : float + Moving image y pixel size (> 0.0). + + Will show an error if the pixel sizes are less than or equal to 0. + Will show an error if the moving image or atlas is not selected. + """ + if x <= 0 or y <= 0: + show_error("Pixel sizes must be greater than 0") + return + + if not (self._moving_image and self._atlas): + show_error( + "Sample image or atlas not selected. " + "Please select a sample image and atlas before scaling", + ) + return + + if self._moving_image_data_backup is None: + self._moving_image_data_backup = self._moving_image.data.copy() + + x_factor = x / self._atlas.resolution[0] + y_factor = y / self._atlas.resolution[1] + + self._moving_image.data = rescale( + self._moving_image_data_backup, + (y_factor, x_factor), + mode="constant", + preserve_range=True, + anti_aliasing=True, + ) + + def _on_adjust_atlas_rotation(self, pitch: float, yaw: float, roll: float): + if not ( + self._atlas + and self._atlas_data_layer + and self._atlas_annotations_layer + ): + show_error( + "No atlas selected. Please select an atlas before rotating" + ) + return + + # Create the rotation matrix + roll_matrix = active_matrix_from_angle(0, np.deg2rad(roll)) + pitch_matrix = active_matrix_from_angle(2, np.deg2rad(pitch)) + yaw_matrix = active_matrix_from_angle(1, np.deg2rad(yaw)) + + rot_matrix = roll_matrix @ pitch_matrix @ yaw_matrix + + full_matrix = np.eye(4) + full_matrix[:-1, :-1] = rot_matrix + + # Translate the origin to the center of the image + origin = np.asarray(self._atlas.reference.shape) // 2 + translate_matrix = np.eye(4) + translate_matrix[:-1, -1] = origin + + bounding_box = calculate_rotated_bounding_box( + self._atlas.reference.shape, full_matrix + ) + new_translation = np.asarray(bounding_box) // 2 + post_rotate_translation = np.eye(4) + post_rotate_translation[:3, -1] = -new_translation + + # Combine the matrices. The order of operations is: + # 1. Translate the origin to the center of the image + # 2. Rotate the image + # 3. Translate the origin back to the top left corner + transform_matrix = ( + translate_matrix @ full_matrix @ post_rotate_translation + ) + + self._atlas_data_layer.data = dask_affine_transform( + self._atlas.reference, + transform_matrix, + order=2, + output_shape=bounding_box, + output_chunks=(2, bounding_box[1], bounding_box[2]), + ) + + self._atlas_annotations_layer.data = dask_affine_transform( + self._atlas.annotation, + transform_matrix, + order=0, + output_shape=bounding_box, + output_chunks=(2, bounding_box[1], bounding_box[2]), + ) + + # Resets the viewer grid to update the grid to the new atlas + self._viewer.reset_view() + + worker = self.compute_atlas_rotation(self._atlas_data_layer.data) + worker.returned.connect(self.set_atlas_layer_data) + worker.start() + + @thread_worker + def compute_atlas_rotation(self, dask_array: da.Array): + self.adjust_moving_image_widget.reset_atlas_button.setEnabled(False) + self.adjust_moving_image_widget.adjust_atlas_rotation.setEnabled(False) + + computed_array = dask_array.compute() + + self.adjust_moving_image_widget.reset_atlas_button.setEnabled(True) + self.adjust_moving_image_widget.adjust_atlas_rotation.setEnabled(True) + + return computed_array + + def set_atlas_layer_data(self, new_data): + self._atlas_data_layer.data = new_data + + def _on_atlas_reset(self): + if not self._atlas: + show_error( + "No atlas selected. Please select an atlas before resetting" + ) + return + + self._atlas_data_layer.data = self._atlas.reference + self._atlas_annotations_layer.data = self._atlas.annotation + self._viewer.grid.enabled = False + self._viewer.grid.enabled = True diff --git a/brainglobe_registration/utils/utils.py b/brainglobe_registration/utils/utils.py index 48d0933..858ebab 100644 --- a/brainglobe_registration/utils/utils.py +++ b/brainglobe_registration/utils/utils.py @@ -1,9 +1,15 @@ from pathlib import Path -from typing import List +from typing import Dict, List, Tuple import napari import numpy as np +import numpy.typing as npt +from brainglobe_atlasapi.list_atlases import get_downloaded_atlases +from brainglobe_utils.qtpy.dialog import display_info from pytransform3d.rotations import active_matrix_from_angle +from qtpy.QtWidgets import QWidget +from scipy.ndimage import gaussian_filter +from skimage import morphology def adjust_napari_image_layer( @@ -20,7 +26,7 @@ def adjust_napari_image_layer( https://forum.image.sc/t/napari-3d-rotation-center-change-and-scaling/66347/5 Parameters - ---------- + ------------ image_layer : napari.layers.Layer The napari image layer to be adjusted. x : int @@ -31,7 +37,7 @@ def adjust_napari_image_layer( The angle of rotation in degrees. Returns - ------- + -------- None """ image_layer.translate = (y, x) @@ -46,7 +52,7 @@ def adjust_napari_image_layer( image_layer.affine = transform_matrix -def open_parameter_file(file_path: Path) -> dict: +def open_parameter_file(file_path: Path) -> Dict: """ Opens the parameter file and returns the parameter dictionary. @@ -56,13 +62,13 @@ def open_parameter_file(file_path: Path) -> dict: The values are stripped of any trailing ")" and leading or trailing quotes. Parameters - ---------- + ------------ file_path : Path The path to the parameter file. Returns - ------- - dict + -------- + Dict A dictionary containing the parameters from the file. """ with open(file_path, "r") as f: @@ -94,5 +100,157 @@ def find_layer_index(viewer: napari.Viewer, layer_name: str) -> int: def get_image_layer_names(viewer: napari.Viewer) -> List[str]: """ Returns a list of the names of the napari image layers in the viewer. + + Parameters + ------------ + viewer : napari.Viewer + The napari viewer containing the image layers. + + Returns + -------- + List[str] + A list of the names of the image layers in the viewer. """ return [layer.name for layer in viewer.layers] + + +def calculate_rotated_bounding_box( + image_shape: Tuple[int, int, int], rotation_matrix: npt.NDArray +) -> Tuple[int, int, int]: + """ + Calculates the bounding box of the rotated image. + + This function calculates the bounding box of the rotated image given the + image shape and rotation matrix. The bounding box is calculated by + transforming the corners of the image and finding the minimum and maximum + values of the transformed corners. + + Parameters + ------------ + image_shape : Tuple[int, int, int] + The shape of the image. + rotation_matrix : npt.NDArray + The rotation matrix. + + Returns + -------- + Tuple[int, int, int] + The bounding box of the rotated image. + """ + corners = np.array( + [ + [0, 0, 0, 1], + [image_shape[0], 0, 0, 1], + [0, image_shape[1], 0, 1], + [0, 0, image_shape[2], 1], + [image_shape[0], image_shape[1], 0, 1], + [image_shape[0], 0, image_shape[2], 1], + [0, image_shape[1], image_shape[2], 1], + [image_shape[0], image_shape[1], image_shape[2], 1], + ] + ) + + transformed_corners = rotation_matrix @ corners.T + min_corner = np.min(transformed_corners, axis=1) + max_corner = np.max(transformed_corners, axis=1) + + return ( + int(np.round(max_corner[0] - min_corner[0])), + int(np.round(max_corner[1] - min_corner[1])), + int(np.round(max_corner[2] - min_corner[2])), + ) + + +def check_atlas_installed(parent_widget: QWidget): + """ + Function checks if user has any atlases installed. If not, message box + appears in napari, directing user to download atlases via attached links. + """ + available_atlases = get_downloaded_atlases() + if len(available_atlases) == 0: + display_info( + widget=parent_widget, + title="Information", + message="No atlases available. Please download atlas(es) " + "using " + "brainglobe-atlasapi or brainrender-napari", + ) + + +def filter_plane(img_plane): + """ + Apply a set of filter to the plane (typically to avoid overfitting details + in the image during registration) + The filter is composed of a despeckle filter using opening and a pseudo + flatfield filter + + Originally from: [https://github.com/brainglobe/brainreg/blob/main + /brainreg/core/utils/preprocess.py] + + Parameters + ---------- + img_plane : np.array + A 2D array to filter + + Returns + ------- + np.array + Filtered image + """ + + img_plane = despeckle_by_opening(img_plane) + img_plane = pseudo_flatfield(img_plane) + return img_plane + + +def despeckle_by_opening(img_plane, radius=2): + """ + Despeckle the image plane using a grayscale opening operation + + Originally from: [https://github.com/brainglobe/brainreg/blob/main + /brainreg/core/utils/preprocess.py] + + Parameters + ---------- + img_plane : np.array + The image to filter + + radius: int + The radius of the opening kernel + + Returns + ------- + np.array + The despeckled image + """ + kernel = morphology.disk(radius) + morphology.opening(img_plane, out=img_plane, footprint=kernel) + return img_plane + + +def pseudo_flatfield(img_plane, sigma=5): + """ + Pseudo flat field filter implementation using a de-trending by a + heavily gaussian filtered copy of the image. + + Originally from: [https://github.com/brainglobe/brainreg/blob/main + /brainreg/core/utils/preprocess.py] + + Parameters + ---------- + img_plane : np.array + The image to filter + + sigma : int + The sigma of the gaussian filter applied to the + image used for de-trending + + Returns + ------- + np.array + The pseudo flat field filtered image + """ + filtered_img = gaussian_filter(img_plane, sigma) + return img_plane / (filtered_img + 1) diff --git a/brainglobe_registration/widgets/adjust_moving_image_view.py b/brainglobe_registration/widgets/adjust_moving_image_view.py index 4740e25..111951d 100644 --- a/brainglobe_registration/widgets/adjust_moving_image_view.py +++ b/brainglobe_registration/widgets/adjust_moving_image_view.py @@ -1,4 +1,4 @@ -from qtpy.QtCore import Signal +from qtpy.QtCore import Qt, Signal from qtpy.QtWidgets import ( QDoubleSpinBox, QFormLayout, @@ -14,8 +14,8 @@ class AdjustMovingImageView(QWidget): A QWidget subclass that provides controls for adjusting the moving image. This widget provides controls for adjusting the x and y offsets and - rotation of the moving image. It emits signals when the image is adjusted - or reset. + rotation of the moving image. It emits signals when the image is adjusted, + scaled, or reset. Attributes ---------- @@ -33,9 +33,18 @@ class AdjustMovingImageView(QWidget): _on_reset_image_button_click(): Resets the x and y offsets and rotation to 0 and emits the reset_image_signal. + _on_scale_image_button_click(): + Emits the scale_image_signal with the entered pixel sizes. + _on_adjust_atlas_rotation(): + Emits the atlas_rotation_signal with the entered pitch, yaw, and roll. + _on_atlas_reset(): + Resets the pitch, yaw, and roll to 0 and emits the atlas_reset_signal. """ adjust_image_signal = Signal(int, int, float) + scale_image_signal = Signal(float, float) + atlas_rotation_signal = Signal(float, float, float) + reset_atlas_signal = Signal() reset_image_signal = Signal() def __init__(self, parent=None): @@ -50,19 +59,53 @@ def __init__(self, parent=None): super().__init__(parent=parent) self.setLayout(QFormLayout()) + self.layout().setLabelAlignment(Qt.AlignLeft) offset_range = 2000 rotation_range = 360 - self.adjust_moving_image_x = QSpinBox() + self.adjust_moving_image_pixel_size_x = QDoubleSpinBox(parent=self) + self.adjust_moving_image_pixel_size_x.setDecimals(3) + self.adjust_moving_image_pixel_size_x.setRange(0.001, 100.00) + self.adjust_moving_image_pixel_size_y = QDoubleSpinBox(parent=self) + self.adjust_moving_image_pixel_size_y.setDecimals(3) + self.adjust_moving_image_pixel_size_y.setRange(0.001, 100.00) + self.scale_moving_image_button = QPushButton() + self.scale_moving_image_button.setText("Scale Image") + self.scale_moving_image_button.clicked.connect( + self._on_scale_image_button_click + ) + + self.adjust_atlas_pitch = QDoubleSpinBox(parent=self) + self.adjust_atlas_pitch.setSingleStep(0.1) + self.adjust_atlas_pitch.setRange(-rotation_range, rotation_range) + + self.adjust_atlas_yaw = QDoubleSpinBox(parent=self) + self.adjust_atlas_yaw.setSingleStep(0.1) + self.adjust_atlas_yaw.setRange(-rotation_range, rotation_range) + + self.adjust_atlas_roll = QDoubleSpinBox(parent=self) + self.adjust_atlas_roll.setSingleStep(0.1) + self.adjust_atlas_roll.setRange(-rotation_range, rotation_range) + + self.adjust_atlas_rotation = QPushButton() + self.adjust_atlas_rotation.setText("Rotate Atlas") + self.adjust_atlas_rotation.clicked.connect( + self._on_adjust_atlas_rotation + ) + self.reset_atlas_button = QPushButton() + self.reset_atlas_button.setText("Reset Atlas") + self.reset_atlas_button.clicked.connect(self._on_atlas_reset) + + self.adjust_moving_image_x = QSpinBox(parent=self) self.adjust_moving_image_x.setRange(-offset_range, offset_range) self.adjust_moving_image_x.valueChanged.connect(self._on_adjust_image) - self.adjust_moving_image_y = QSpinBox() + self.adjust_moving_image_y = QSpinBox(parent=self) self.adjust_moving_image_y.setRange(-offset_range, offset_range) self.adjust_moving_image_y.valueChanged.connect(self._on_adjust_image) - self.adjust_moving_image_rotate = QDoubleSpinBox() + self.adjust_moving_image_rotate = QDoubleSpinBox(parent=self) self.adjust_moving_image_rotate.setRange( -rotation_range, rotation_range ) @@ -71,13 +114,31 @@ def __init__(self, parent=None): self._on_adjust_image ) - self.adjust_moving_image_reset_button = QPushButton(parent=self) + self.adjust_moving_image_reset_button = QPushButton() self.adjust_moving_image_reset_button.setText("Reset Image") self.adjust_moving_image_reset_button.clicked.connect( self._on_reset_image_button_click ) - self.layout().addRow(QLabel("Adjust the moving image: ")) + self.layout().addRow(QLabel("Adjust the moving image scale:")) + self.layout().addRow( + "Sample image X pixel size (\u03BCm / pixel):", + self.adjust_moving_image_pixel_size_x, + ) + self.layout().addRow( + "Sample image Y pixel size (\u03BCm / pixel):", + self.adjust_moving_image_pixel_size_y, + ) + self.layout().addRow(self.scale_moving_image_button) + + self.layout().addRow(QLabel("Adjust the atlas pitch and yaw: ")) + self.layout().addRow("Pitch:", self.adjust_atlas_pitch) + self.layout().addRow("Yaw:", self.adjust_atlas_yaw) + self.layout().addRow("Roll:", self.adjust_atlas_roll) + self.layout().addRow(self.adjust_atlas_rotation) + self.layout().addRow(self.reset_atlas_button) + + self.layout().addRow(QLabel("Adjust the moving image position: ")) self.layout().addRow("X offset:", self.adjust_moving_image_x) self.layout().addRow("Y offset:", self.adjust_moving_image_y) self.layout().addRow( @@ -106,3 +167,32 @@ def _on_reset_image_button_click(self): self.adjust_moving_image_rotate.setValue(0) self.reset_image_signal.emit() + + def _on_scale_image_button_click(self): + """ + Emit the scale_image_signal with the entered pixel sizes. + """ + self.scale_image_signal.emit( + self.adjust_moving_image_pixel_size_x.value(), + self.adjust_moving_image_pixel_size_y.value(), + ) + + def _on_adjust_atlas_rotation(self): + """ + Emit the atlas_rotation_signal with the entered pitch and yaw. + """ + self.atlas_rotation_signal.emit( + self.adjust_atlas_pitch.value(), + self.adjust_atlas_yaw.value(), + self.adjust_atlas_roll.value(), + ) + + def _on_atlas_reset(self): + """ + Reset the pitch, yaw, and roll to 0 and emit the atlas_reset_signal. + """ + self.adjust_atlas_yaw.setValue(0) + self.adjust_atlas_pitch.setValue(0) + self.adjust_atlas_roll.setValue(0) + + self.reset_atlas_signal.emit() diff --git a/brainglobe_registration/widgets/select_images_view.py b/brainglobe_registration/widgets/select_images_view.py index 17bc737..ad9046c 100644 --- a/brainglobe_registration/widgets/select_images_view.py +++ b/brainglobe_registration/widgets/select_images_view.py @@ -98,7 +98,8 @@ def __init__( def update_sample_image_names(self, sample_image_names: List[str]): self.available_sample_dropdown.clear() - self.available_sample_dropdown.addItems(sample_image_names) + if len(sample_image_names) != 0: + self.available_sample_dropdown.addItems(sample_image_names) def _on_atlas_index_change(self): """ @@ -122,5 +123,9 @@ def _on_moving_image_index_change(self): self.available_sample_dropdown.currentIndex() ) + def reset_atlas_combobox(self): + if self.available_atlas_dropdown is not None: + self.available_atlas_dropdown.setCurrentIndex(0) + def _on_sample_popup_about_to_show(self): self.sample_image_popup_about_to_show.emit() diff --git a/brainglobe_registration/widgets/transform_select_view.py b/brainglobe_registration/widgets/transform_select_view.py index eba8c32..d141816 100644 --- a/brainglobe_registration/widgets/transform_select_view.py +++ b/brainglobe_registration/widgets/transform_select_view.py @@ -169,7 +169,7 @@ def _on_transform_type_change(self, index): self.file_selections[index].setCurrentIndex(0) if index >= len(self.transform_type_selections) - 1: - curr_length = self.rowCount() + current_length = self.rowCount() self.setRowCount(self.rowCount() + 1) self.transform_type_selections.append(QComboBox()) @@ -180,7 +180,7 @@ def _on_transform_type_change(self, index): self.transform_type_signaller.map ) self.transform_type_signaller.setMapping( - self.transform_type_selections[-1], curr_length + self.transform_type_selections[-1], current_length ) self.file_selections.append(QComboBox()) @@ -191,14 +191,16 @@ def _on_transform_type_change(self, index): self.file_signaller.map ) self.file_signaller.setMapping( - self.file_selections[-1], curr_length + self.file_selections[-1], current_length ) self.setCellWidget( - curr_length, 0, self.transform_type_selections[curr_length] + current_length, + 0, + self.transform_type_selections[current_length], ) self.setCellWidget( - curr_length, 1, self.file_selections[curr_length] + current_length, 1, self.file_selections[current_length] ) else: diff --git a/pyproject.toml b/pyproject.toml index 9317c22..0d14999 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,8 @@ dependencies = [ "napari>=0.4.18", "brainglobe-atlasapi", "brainglobe-utils>=0.4.3", + "dask", + "dask-image", "itk-elastix", "lxml_html_clean", "numpy", diff --git a/tests/test_adjust_moving_image_view.py b/tests/test_adjust_moving_image_view.py index e56107e..8eae56c 100644 --- a/tests/test_adjust_moving_image_view.py +++ b/tests/test_adjust_moving_image_view.py @@ -14,6 +14,12 @@ def adjust_moving_image_view() -> AdjustMovingImageView: return adjust_moving_image_view +def test_init(qtbot, adjust_moving_image_view): + qtbot.addWidget(adjust_moving_image_view) + + assert adjust_moving_image_view.layout().rowCount() == 15 + + @pytest.mark.parametrize( "x_value, expected", [ @@ -94,3 +100,62 @@ def test_reset_image_button_click(qtbot, adjust_moving_image_view): assert adjust_moving_image_view.adjust_moving_image_x.value() == 0 assert adjust_moving_image_view.adjust_moving_image_y.value() == 0 assert adjust_moving_image_view.adjust_moving_image_rotate.value() == 0 + + +@pytest.mark.parametrize( + "x_scale, y_scale", + [(2.5, 2.5), (10, 20), (10.2212, 10.2289)], +) +def test_scale_image_button_click( + qtbot, adjust_moving_image_view, x_scale, y_scale +): + qtbot.addWidget(adjust_moving_image_view) + + with qtbot.waitSignal( + adjust_moving_image_view.scale_image_signal, timeout=1000 + ) as blocker: + adjust_moving_image_view.adjust_moving_image_pixel_size_x.setValue( + x_scale + ) + adjust_moving_image_view.adjust_moving_image_pixel_size_y.setValue( + y_scale + ) + adjust_moving_image_view.scale_moving_image_button.click() + + assert blocker.args == [round(x_scale, 3), round(y_scale, 3)] + + +def test_atlas_rotation_changed( + qtbot, + adjust_moving_image_view, +): + qtbot.addWidget(adjust_moving_image_view) + + with qtbot.waitSignal( + adjust_moving_image_view.atlas_rotation_signal, timeout=1000 + ) as blocker: + adjust_moving_image_view.adjust_atlas_pitch.setValue(10) + adjust_moving_image_view.adjust_atlas_yaw.setValue(20) + adjust_moving_image_view.adjust_atlas_roll.setValue(30) + + adjust_moving_image_view.adjust_atlas_rotation.click() + + assert blocker.args == [10, 20, 30] + + +def test_atlas_reset_button_click( + qtbot, + adjust_moving_image_view, +): + qtbot.addWidget(adjust_moving_image_view) + + with qtbot.waitSignal( + adjust_moving_image_view.reset_atlas_signal, timeout=1000 + ): + adjust_moving_image_view.reset_atlas_button.click() + + assert ( + adjust_moving_image_view.adjust_atlas_pitch.value() == 0 + and adjust_moving_image_view.adjust_atlas_yaw.value() == 0 + and adjust_moving_image_view.adjust_atlas_roll.value() == 0 + ) diff --git a/tests/test_parameter_list_view.py b/tests/test_parameter_list_view.py index 00a701f..cabbfce 100644 --- a/tests/test_parameter_list_view.py +++ b/tests/test_parameter_list_view.py @@ -56,7 +56,7 @@ def test_parameter_list_view_cell_change(parameter_list_view, qtbot): def test_parameter_list_view_cell_change_last_row(parameter_list_view, qtbot): qtbot.addWidget(parameter_list_view) - curr_row_count = parameter_list_view.rowCount() + current_row_count = parameter_list_view.rowCount() last_row_index = len(param_dict) parameter_list_view.setItem( @@ -65,7 +65,7 @@ def test_parameter_list_view_cell_change_last_row(parameter_list_view, qtbot): parameter_list_view.setItem(last_row_index, 1, QTableWidgetItem("true")) assert parameter_list_view.param_dict["TestParameter"] == ["true"] - assert parameter_list_view.rowCount() == curr_row_count + 1 + assert parameter_list_view.rowCount() == current_row_count + 1 def test_parameter_list_view_cell_change_last_row_no_param( @@ -73,12 +73,12 @@ def test_parameter_list_view_cell_change_last_row_no_param( ): qtbot.addWidget(parameter_list_view) - curr_row_count = parameter_list_view.rowCount() + current_row_count = parameter_list_view.rowCount() last_row_index = len(param_dict) parameter_list_view.setItem(last_row_index, 1, QTableWidgetItem("true")) - assert parameter_list_view.rowCount() == curr_row_count + assert parameter_list_view.rowCount() == current_row_count def test_parameter_list_view_cell_change_last_row_no_value( @@ -86,11 +86,11 @@ def test_parameter_list_view_cell_change_last_row_no_value( ): qtbot.addWidget(parameter_list_view) - curr_row_count = parameter_list_view.rowCount() + current_row_count = parameter_list_view.rowCount() last_row_index = len(param_dict) parameter_list_view.setItem( last_row_index, 0, QTableWidgetItem("TestParameter") ) - assert parameter_list_view.rowCount() == curr_row_count + assert parameter_list_view.rowCount() == current_row_count diff --git a/tests/test_registration_widget.py b/tests/test_registration_widget.py index aa91594..828c4aa 100644 --- a/tests/test_registration_widget.py +++ b/tests/test_registration_widget.py @@ -12,6 +12,29 @@ def registration_widget(make_napari_viewer_with_images): return widget +@pytest.fixture() +def registration_widget_with_example_atlas(make_napari_viewer_with_images): + """ + Create an initialised RegistrationWidget with the "example_mouse_100um" + loaded. + + Parameters + ------------ + make_napari_viewer_with_images for testing + Fixture that creates a napari viewer + """ + viewer = make_napari_viewer_with_images + + widget = RegistrationWidget(viewer) + + # Based on the downloaded atlases by the fixture in conftest.py, + # the example atlas will be in the third position. + example_atlas_index = 2 + widget._on_atlas_dropdown_index_changed(example_atlas_index) + + return widget + + def test_registration_widget(make_napari_viewer_with_images): widget = RegistrationWidget(make_napari_viewer_with_images) @@ -65,3 +88,129 @@ def test_sample_dropdown_index_changed_with_valid_index( registration_widget._moving_image.name == registration_widget._sample_images[1] ) + + +def test_scale_moving_image_no_atlas( + make_napari_viewer_with_images, registration_widget, mocker +): + mocked_show_error = mocker.patch( + "brainglobe_registration.registration_widget.show_error" + ) + registration_widget._atlas = None + registration_widget.adjust_moving_image_widget.scale_image_signal.emit( + 10, 10 + ) + mocked_show_error.assert_called_once_with( + "Sample image or atlas not selected. " + "Please select a sample image and atlas before scaling" + ) + + +def test_scale_moving_image_no_sample_image( + make_napari_viewer_with_images, registration_widget, mocker +): + mocked_show_error = mocker.patch( + "brainglobe_registration.registration_widget.show_error" + ) + registration_widget._moving_image = None + registration_widget.adjust_moving_image_widget.scale_image_signal.emit( + 10, 10 + ) + mocked_show_error.assert_called_once_with( + "Sample image or atlas not selected. " + "Please select a sample image and atlas before scaling" + ) + + +@pytest.mark.parametrize( + "x_scale_factor, y_scale_factor", + [ + (0.5, 0.5), + (1.0, 1.0), + (2.0, 2.0), + (0.5, 1.0), + (1.0, 0.5), + ], +) +def test_scale_moving_image( + make_napari_viewer_with_images, + registration_widget, + mocker, + x_scale_factor, + y_scale_factor, +): + mock_atlas = mocker.patch( + "brainglobe_registration.registration_widget.BrainGlobeAtlas" + ) + mock_atlas.resolution = [20, 20, 20] + registration_widget._atlas = mock_atlas + + current_size = registration_widget._moving_image.data.shape + registration_widget.adjust_moving_image_widget.scale_image_signal.emit( + mock_atlas.resolution[1] * x_scale_factor, + mock_atlas.resolution[2] * y_scale_factor, + ) + + assert registration_widget._moving_image.data.shape == ( + current_size[0] * y_scale_factor, + current_size[1] * x_scale_factor, + ) + + +@pytest.mark.parametrize( + "pitch, yaw, roll, expected_shape", + [ + (0, 0, 0, (132, 80, 114)), + (45, 0, 0, (150, 150, 114)), + (0, 45, 0, (174, 80, 174)), + (0, 0, 45, (132, 137, 137)), + ], +) +def test_on_adjust_atlas_rotation( + registration_widget_with_example_atlas, + pitch, + yaw, + roll, + expected_shape, +): + reg_widget = registration_widget_with_example_atlas + atlas_shape = reg_widget._atlas.reference.shape + + reg_widget._on_adjust_atlas_rotation(pitch, yaw, roll) + + assert reg_widget._atlas_data_layer.data.shape == expected_shape + assert reg_widget._atlas_annotations_layer.data.shape == expected_shape + assert reg_widget._atlas.reference.shape == atlas_shape + + +def test_on_adjust_atlas_rotation_no_atlas(registration_widget, mocker): + mocked_show_error = mocker.patch( + "brainglobe_registration.registration_widget.show_error" + ) + registration_widget._on_adjust_atlas_rotation(10, 10, 10) + mocked_show_error.assert_called_once_with( + "No atlas selected. Please select an atlas before rotating" + ) + + +def test_on_atlas_reset(registration_widget_with_example_atlas): + reg_widget = registration_widget_with_example_atlas + atlas_shape = reg_widget._atlas.reference.shape + reg_widget._on_adjust_atlas_rotation(10, 10, 10) + + reg_widget._on_atlas_reset() + + assert reg_widget._atlas_data_layer.data.shape == atlas_shape + assert reg_widget._atlas.reference.shape == atlas_shape + assert reg_widget._atlas_annotations_layer.data.shape == atlas_shape + + +def test_on_atlas_reset_no_atlas(registration_widget, mocker): + mocked_show_error = mocker.patch( + "brainglobe_registration.registration_widget.show_error" + ) + + registration_widget._on_atlas_reset() + mocked_show_error.assert_called_once_with( + "No atlas selected. Please select an atlas before resetting" + ) diff --git a/tests/test_utils.py b/tests/test_utils.py index d857e09..89068e7 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -7,6 +7,7 @@ from brainglobe_registration.utils.utils import ( adjust_napari_image_layer, + calculate_rotated_bounding_box, find_layer_index, get_image_layer_names, open_parameter_file, @@ -93,3 +94,32 @@ def test_get_image_layer_names(make_napari_viewer_with_images): assert len(layer_names) == 2 assert layer_names == ["moving_image", "atlas_image"] + + +@pytest.mark.parametrize( + "basis, rotation, expected_bounds", + [ + (0, 0, (50, 100, 200)), + (0, 90, (50, 200, 100)), + (0, 180, (50, 100, 200)), + (0, 45, (50, 212, 212)), + (1, 0, (50, 100, 200)), + (1, 90, (200, 100, 50)), + (1, 180, (50, 100, 200)), + (1, 45, (177, 100, 177)), + (2, 0, (50, 100, 200)), + (2, 90, (100, 50, 200)), + (2, 180, (50, 100, 200)), + (2, 45, (106, 106, 200)), + ], +) +def test_calculate_rotated_bounding_box(basis, rotation, expected_bounds): + image_shape = (50, 100, 200) + rotation_matrix = np.eye(4) + rotation_matrix[:-1, :-1] = active_matrix_from_angle( + basis, np.deg2rad(rotation) + ) + + result_shape = calculate_rotated_bounding_box(image_shape, rotation_matrix) + + assert result_shape == expected_bounds