diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index c00e554..de1d2d5 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -47,10 +47,28 @@ jobs: with: qt: true + # cache atlases needed by the tests + - name: Cache Atlases + id: atlas-cache + uses: actions/cache@v3 + with: + path: | # ensure we don't cache any interrupted atlas download and extraction! + ~/.brainglobe/* + !~/.brainglobe/atlas.tar.gz + key: ${{ runner.os }}-cached-atlases + enableCrossOsArchive: false # ~ and $HOME evaluate to different places across OSs! + + - if: ${{ steps.atlas-cache.outputs.cache-hit == 'true' }} + name: List files in brainglobe data folder # good to be able to sanity check that user data is as expected + run: | + ls -af ~/.brainglobe/ + + # Run tests - uses: neuroinformatics-unit/actions/test@v2 with: python-version: ${{ matrix.python-version }} + secret-codecov-token: ${{ secrets.CODECOV_TOKEN }} build_sdist_wheels: name: Build source distribution @@ -66,11 +84,6 @@ jobs: needs: [build_sdist_wheels] runs-on: ubuntu-latest steps: - - uses: actions/download-artifact@v3 - with: - name: artifact - path: dist - - uses: pypa/gh-action-pypi-publish@v1.5.0 + - uses: neuroinformatics-unit/actions/upload_pypi@v2 with: - user: __token__ - password: ${{ secrets.TWINE_API_KEY }} + secret-pypi-key: ${{ secrets.TWINE_API_KEY }} diff --git a/.gitignore b/.gitignore index 73d56d3..08e97e3 100644 --- a/.gitignore +++ b/.gitignore @@ -82,3 +82,6 @@ venv/ # written by setuptools_scm **/_version.py + +# Elastix related files +/Logs/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6ffb785..f719a01 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,15 +17,15 @@ repos: - id: requirements-txt-fixer - id: trailing-whitespace - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.9 + rev: v0.3.5 hooks: - id: ruff - repo: https://github.com/psf/black - rev: 23.12.1 + rev: 24.3.0 hooks: - id: black - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.8.0 + rev: v1.9.0 hooks: - id: mypy additional_dependencies: diff --git a/MANIFEST.in b/MANIFEST.in index b0e8711..902dcc4 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,7 +3,6 @@ include README.md include .napari-hub/DESCRIPTION.md include .napari-hub/config.yml include brainglobe_registration/napari.yaml -include brainglobe_registration/resources/brainglobe.png exclude .pre-commit-config.yaml recursive-include brainglobe_registration/parameters *.txt diff --git a/brainglobe_registration/elastix/register.py b/brainglobe_registration/elastix/register.py index 955ed44..d04d918 100644 --- a/brainglobe_registration/elastix/register.py +++ b/brainglobe_registration/elastix/register.py @@ -2,7 +2,7 @@ import itk import numpy as np -from bg_atlasapi import BrainGlobeAtlas +from brainglobe_atlasapi import BrainGlobeAtlas def get_atlas_by_name(atlas_name: str) -> BrainGlobeAtlas: diff --git a/brainglobe_registration/registration_widget.py b/brainglobe_registration/registration_widget.py index 67ab52f..35016fd 100644 --- a/brainglobe_registration/registration_widget.py +++ b/brainglobe_registration/registration_widget.py @@ -9,13 +9,21 @@ """ from pathlib import Path +from typing import Optional +import dask.array as da +import napari.layers import numpy as np -from bg_atlasapi import BrainGlobeAtlas -from bg_atlasapi.list_atlases import get_downloaded_atlases +import numpy.typing as npt +from brainglobe_atlasapi import BrainGlobeAtlas +from brainglobe_atlasapi.list_atlases import get_downloaded_atlases from brainglobe_utils.qtpy.collapsible_widget import CollapsibleWidgetContainer +from brainglobe_utils.qtpy.logo import header_widget +from dask_image.ndinterp import affine_transform as dask_affine_transform +from napari.qt.threading import thread_worker from napari.utils.notifications import show_error from napari.viewer import Viewer +from pytransform3d.rotations import active_matrix_from_angle from qtpy.QtWidgets import ( QPushButton, QTabWidget, @@ -24,9 +32,9 @@ from skimage.transform import rescale from brainglobe_registration.elastix.register import run_registration -from brainglobe_registration.utils.brainglobe_logo import header_widget from brainglobe_registration.utils.utils import ( adjust_napari_image_layer, + calculate_rotated_bounding_box, find_layer_index, get_image_layer_names, open_parameter_file, @@ -49,9 +57,11 @@ def __init__(self, napari_viewer: Viewer): self.setContentsMargins(10, 10, 10, 10) self._viewer = napari_viewer - self._atlas: BrainGlobeAtlas = None - self._moving_image = None - self._moving_image_data_backup = None + self._atlas: Optional[BrainGlobeAtlas] = None + self._atlas_data_layer: Optional[napari.layers.Image] = None + self._atlas_annotations_layer: Optional[napari.layers.Labels] = None + self._moving_image: Optional[napari.layers.Image] = None + self._moving_image_data_backup: Optional[npt.NDArray] = None self.transform_params: dict[str, dict] = { "rigid": {}, @@ -110,6 +120,12 @@ def __init__(self, napari_viewer: Viewer): self.adjust_moving_image_widget.scale_image_signal.connect( self._on_scale_moving_image ) + self.adjust_moving_image_widget.atlas_rotation_signal.connect( + self._on_adjust_atlas_rotation + ) + self.adjust_moving_image_widget.reset_atlas_signal.connect( + self._on_atlas_reset + ) self.transform_select_view = TransformSelectView() self.transform_select_view.transform_type_added_signal.connect( @@ -126,10 +142,17 @@ def __init__(self, napari_viewer: Viewer): self.run_button.clicked.connect(self._on_run_button_click) self.run_button.setEnabled(False) - self.add_widget(header_widget(), collapsible=False) + self.add_widget( + header_widget( + "brainglobe-
registration", # line break at
+ "Registration with Elastix", + github_repo_name="brainglobe-registration", + ), + collapsible=False, + ) self.add_widget(self.get_atlas_widget, widget_title="Select Images") self.add_widget( - self.adjust_moving_image_widget, widget_title="Adjust Sample Image" + self.adjust_moving_image_widget, widget_title="Prepare Images" ) self.add_widget( self.transform_select_view, widget_title="Select Transformations" @@ -162,13 +185,14 @@ def _on_atlas_dropdown_index_changed(self, index): self._viewer.layers.pop(current_atlas_layer_index) self._atlas = None + self._atlas_data_layer = None + self._atlas_annotations_layer = None self.run_button.setEnabled(False) self._viewer.grid.enabled = False return atlas_name = self._available_atlases[index] - atlas = BrainGlobeAtlas(atlas_name) if self._atlas: current_atlas_layer_index = find_layer_index( @@ -179,14 +203,41 @@ def _on_atlas_dropdown_index_changed(self, index): else: self.run_button.setEnabled(True) - self._viewer.add_image( - atlas.reference, + self._atlas = BrainGlobeAtlas(atlas_name) + dask_reference = da.from_array( + self._atlas.reference, + chunks=( + 1, + self._atlas.reference.shape[1], + self._atlas.reference.shape[2], + ), + ) + dask_annotations = da.from_array( + self._atlas.annotation, + chunks=( + 1, + self._atlas.annotation.shape[1], + self._atlas.annotation.shape[2], + ), + ) + + contrast_max = np.max( + dask_reference[dask_reference.shape[0] // 2] + ).compute() + self._atlas_data_layer = self._viewer.add_image( + dask_reference, name=atlas_name, colormap="gray", blending="translucent", + contrast_limits=[0, contrast_max], + multiscale=False, + ) + self._atlas_annotations_layer = self._viewer.add_labels( + dask_annotations, + name="Annotations", + visible=False, ) - self._atlas = BrainGlobeAtlas(atlas_name=atlas_name) self._viewer.grid.enabled = True def _on_sample_dropdown_index_changed(self, index): @@ -197,18 +248,31 @@ def _on_sample_dropdown_index_changed(self, index): self._moving_image_data_backup = self._moving_image.data.copy() def _on_adjust_moving_image(self, x: int, y: int, rotate: float): + if not self._moving_image: + show_error( + "No moving image selected. " + "Please select a moving image before adjusting" + ) + return adjust_napari_image_layer(self._moving_image, x, y, rotate) def _on_adjust_moving_image_reset_button_click(self): + if not self._moving_image: + show_error( + "No moving image selected. " + "Please select a moving image before adjusting" + ) + return adjust_napari_image_layer(self._moving_image, 0, 0, 0) def _on_run_button_click(self): + current_atlas_slice = self._viewer.dims.current_step[0] result, parameters, registered_annotation_image = run_registration( - self._atlas.reference[current_atlas_slice, :, :], + self._atlas_data_layer.data[current_atlas_slice, :, :], self._moving_image.data, - self._atlas.annotation[current_atlas_slice, :, :], + self._atlas_annotations_layer.data[current_atlas_slice, :, :], self.transform_selections, ) @@ -311,7 +375,7 @@ def _on_scale_moving_image(self, x: float, y: float): Scale the moving image to have resolution equal to the atlas. Parameters - ---------- + ------------ x : float Moving image x pixel size (> 0.0). y : float @@ -324,22 +388,114 @@ def _on_scale_moving_image(self, x: float, y: float): show_error("Pixel sizes must be greater than 0") return - if self._moving_image and self._atlas: - if self._moving_image_data_backup is None: - self._moving_image_data_backup = self._moving_image.data.copy() + if not (self._moving_image and self._atlas): + show_error( + "Sample image or atlas not selected. " + "Please select a sample image and atlas before scaling", + ) + return + + if self._moving_image_data_backup is None: + self._moving_image_data_backup = self._moving_image.data.copy() + + x_factor = x / self._atlas.resolution[0] + y_factor = y / self._atlas.resolution[1] - x_factor = x / self._atlas.resolution[0] - y_factor = y / self._atlas.resolution[1] + self._moving_image.data = rescale( + self._moving_image_data_backup, + (y_factor, x_factor), + mode="constant", + preserve_range=True, + anti_aliasing=True, + ) - self._moving_image.data = rescale( - self._moving_image_data_backup, - (y_factor, x_factor), - mode="constant", - preserve_range=True, - anti_aliasing=True, + def _on_adjust_atlas_rotation(self, pitch: float, yaw: float, roll: float): + if not ( + self._atlas + and self._atlas_data_layer + and self._atlas_annotations_layer + ): + show_error( + "No atlas selected. Please select an atlas before rotating" ) - else: + return + + # Create the rotation matrix + roll_matrix = active_matrix_from_angle(0, np.deg2rad(roll)) + pitch_matrix = active_matrix_from_angle(2, np.deg2rad(pitch)) + yaw_matrix = active_matrix_from_angle(1, np.deg2rad(yaw)) + + rot_matrix = roll_matrix @ pitch_matrix @ yaw_matrix + + full_matrix = np.eye(4) + full_matrix[:-1, :-1] = rot_matrix + + # Translate the origin to the center of the image + origin = np.asarray(self._atlas.reference.shape) // 2 + translate_matrix = np.eye(4) + translate_matrix[:-1, -1] = origin + + bounding_box = calculate_rotated_bounding_box( + self._atlas.reference.shape, full_matrix + ) + new_translation = np.asarray(bounding_box) // 2 + post_rotate_translation = np.eye(4) + post_rotate_translation[:3, -1] = -new_translation + + # Combine the matrices. The order of operations is: + # 1. Translate the origin to the center of the image + # 2. Rotate the image + # 3. Translate the origin back to the top left corner + transform_matrix = ( + translate_matrix @ full_matrix @ post_rotate_translation + ) + + self._atlas_data_layer.data = dask_affine_transform( + self._atlas.reference, + transform_matrix, + order=2, + output_shape=bounding_box, + output_chunks=(2, bounding_box[1], bounding_box[2]), + ) + + self._atlas_annotations_layer.data = dask_affine_transform( + self._atlas.annotation, + transform_matrix, + order=0, + output_shape=bounding_box, + output_chunks=(2, bounding_box[1], bounding_box[2]), + ) + + # Resets the viewer grid to update the grid to the new atlas + self._viewer.reset_view() + + worker = self.compute_atlas_rotation(self._atlas_data_layer.data) + worker.returned.connect(self.set_atlas_layer_data) + worker.start() + + @thread_worker + def compute_atlas_rotation(self, dask_array: da.Array): + self.adjust_moving_image_widget.reset_atlas_button.setEnabled(False) + self.adjust_moving_image_widget.adjust_atlas_rotation.setEnabled(False) + + computed_array = dask_array.compute() + + self.adjust_moving_image_widget.reset_atlas_button.setEnabled(True) + self.adjust_moving_image_widget.adjust_atlas_rotation.setEnabled(True) + + return computed_array + + def set_atlas_layer_data(self, new_data): + self._atlas_data_layer.data = new_data + + def _on_atlas_reset(self): + if not self._atlas: show_error( - "Sample image or atlas not selected. " - "Please select a sample image and atlas before scaling", + "No atlas selected. Please select an atlas before resetting" ) + return + + self._atlas_data_layer.data = self._atlas.reference + self._atlas_annotations_layer.data = self._atlas.annotation + self._viewer.grid.enabled = False + self._viewer.grid.enabled = True diff --git a/brainglobe_registration/resources/brainglobe.png b/brainglobe_registration/resources/brainglobe.png deleted file mode 100644 index 427bdab..0000000 Binary files a/brainglobe_registration/resources/brainglobe.png and /dev/null differ diff --git a/brainglobe_registration/utils/brainglobe_logo.py b/brainglobe_registration/utils/brainglobe_logo.py deleted file mode 100644 index e387adb..0000000 --- a/brainglobe_registration/utils/brainglobe_logo.py +++ /dev/null @@ -1,46 +0,0 @@ -from importlib.resources import files - -from qtpy.QtWidgets import QGroupBox, QHBoxLayout, QLabel, QVBoxLayout, QWidget - -brainglobe_logo = files("brainglobe_registration").joinpath( - "resources/brainglobe.png" -) - -_logo_html = f""" -

- -<\h1> -""" - - -def _docs_links_widget(parent: QWidget = None): - _docs_links_html = """ -

-

Website

-

Source

-

- """ # noqa: E501 - docs_links_widget = QLabel(_docs_links_html, parent=parent) - docs_links_widget.setOpenExternalLinks(True) - return docs_links_widget - - -def _logo_widget(parent: QWidget = None): - return QLabel(_logo_html, parent=None) - - -def header_widget(parent: QWidget = None): - box = QGroupBox(parent) - box.setFlat(True) - box.setLayout(QVBoxLayout()) - box.layout().addWidget(QLabel("

brainglobe-registration

")) - subbox = QGroupBox(parent) - subbox.setFlat(True) - subbox.setLayout(QHBoxLayout()) - subbox.layout().setSpacing(0) - subbox.layout().setContentsMargins(0, 0, 0, 0) - subbox.setStyleSheet("border: 0;") - subbox.layout().addWidget(_logo_widget(parent=box)) - subbox.layout().addWidget(_docs_links_widget(parent=box)) - box.layout().addWidget(subbox) - return box diff --git a/brainglobe_registration/utils/utils.py b/brainglobe_registration/utils/utils.py index 48d0933..2bb9891 100644 --- a/brainglobe_registration/utils/utils.py +++ b/brainglobe_registration/utils/utils.py @@ -1,8 +1,9 @@ from pathlib import Path -from typing import List +from typing import Dict, List, Tuple import napari import numpy as np +import numpy.typing as npt from pytransform3d.rotations import active_matrix_from_angle @@ -20,7 +21,7 @@ def adjust_napari_image_layer( https://forum.image.sc/t/napari-3d-rotation-center-change-and-scaling/66347/5 Parameters - ---------- + ------------ image_layer : napari.layers.Layer The napari image layer to be adjusted. x : int @@ -31,7 +32,7 @@ def adjust_napari_image_layer( The angle of rotation in degrees. Returns - ------- + -------- None """ image_layer.translate = (y, x) @@ -46,7 +47,7 @@ def adjust_napari_image_layer( image_layer.affine = transform_matrix -def open_parameter_file(file_path: Path) -> dict: +def open_parameter_file(file_path: Path) -> Dict: """ Opens the parameter file and returns the parameter dictionary. @@ -56,13 +57,13 @@ def open_parameter_file(file_path: Path) -> dict: The values are stripped of any trailing ")" and leading or trailing quotes. Parameters - ---------- + ------------ file_path : Path The path to the parameter file. Returns - ------- - dict + -------- + Dict A dictionary containing the parameters from the file. """ with open(file_path, "r") as f: @@ -94,5 +95,62 @@ def find_layer_index(viewer: napari.Viewer, layer_name: str) -> int: def get_image_layer_names(viewer: napari.Viewer) -> List[str]: """ Returns a list of the names of the napari image layers in the viewer. + + Parameters + ------------ + viewer : napari.Viewer + The napari viewer containing the image layers. + + Returns + -------- + List[str] + A list of the names of the image layers in the viewer. """ return [layer.name for layer in viewer.layers] + + +def calculate_rotated_bounding_box( + image_shape: Tuple[int, int, int], rotation_matrix: npt.NDArray +) -> Tuple[int, int, int]: + """ + Calculates the bounding box of the rotated image. + + This function calculates the bounding box of the rotated image given the + image shape and rotation matrix. The bounding box is calculated by + transforming the corners of the image and finding the minimum and maximum + values of the transformed corners. + + Parameters + ------------ + image_shape : Tuple[int, int, int] + The shape of the image. + rotation_matrix : npt.NDArray + The rotation matrix. + + Returns + -------- + Tuple[int, int, int] + The bounding box of the rotated image. + """ + corners = np.array( + [ + [0, 0, 0, 1], + [image_shape[0], 0, 0, 1], + [0, image_shape[1], 0, 1], + [0, 0, image_shape[2], 1], + [image_shape[0], image_shape[1], 0, 1], + [image_shape[0], 0, image_shape[2], 1], + [0, image_shape[1], image_shape[2], 1], + [image_shape[0], image_shape[1], image_shape[2], 1], + ] + ) + + transformed_corners = rotation_matrix @ corners.T + min_corner = np.min(transformed_corners, axis=1) + max_corner = np.max(transformed_corners, axis=1) + + return ( + int(np.round(max_corner[0] - min_corner[0])), + int(np.round(max_corner[1] - min_corner[1])), + int(np.round(max_corner[2] - min_corner[2])), + ) diff --git a/brainglobe_registration/widgets/adjust_moving_image_view.py b/brainglobe_registration/widgets/adjust_moving_image_view.py index c7ba3fd..111951d 100644 --- a/brainglobe_registration/widgets/adjust_moving_image_view.py +++ b/brainglobe_registration/widgets/adjust_moving_image_view.py @@ -1,4 +1,4 @@ -from qtpy.QtCore import Signal +from qtpy.QtCore import Qt, Signal from qtpy.QtWidgets import ( QDoubleSpinBox, QFormLayout, @@ -35,10 +35,16 @@ class AdjustMovingImageView(QWidget): reset_image_signal. _on_scale_image_button_click(): Emits the scale_image_signal with the entered pixel sizes. + _on_adjust_atlas_rotation(): + Emits the atlas_rotation_signal with the entered pitch, yaw, and roll. + _on_atlas_reset(): + Resets the pitch, yaw, and roll to 0 and emits the atlas_reset_signal. """ adjust_image_signal = Signal(int, int, float) scale_image_signal = Signal(float, float) + atlas_rotation_signal = Signal(float, float, float) + reset_atlas_signal = Signal() reset_image_signal = Signal() def __init__(self, parent=None): @@ -53,22 +59,44 @@ def __init__(self, parent=None): super().__init__(parent=parent) self.setLayout(QFormLayout()) + self.layout().setLabelAlignment(Qt.AlignLeft) offset_range = 2000 rotation_range = 360 - self.adjust_moving_image_voxel_size_x = QDoubleSpinBox(parent=self) - self.adjust_moving_image_voxel_size_x.setDecimals(2) - self.adjust_moving_image_voxel_size_x.setRange(0.01, 100.00) - self.adjust_moving_image_voxel_size_y = QDoubleSpinBox(parent=self) - self.adjust_moving_image_voxel_size_y.setDecimals(2) - self.adjust_moving_image_voxel_size_y.setRange(0.01, 100.00) + self.adjust_moving_image_pixel_size_x = QDoubleSpinBox(parent=self) + self.adjust_moving_image_pixel_size_x.setDecimals(3) + self.adjust_moving_image_pixel_size_x.setRange(0.001, 100.00) + self.adjust_moving_image_pixel_size_y = QDoubleSpinBox(parent=self) + self.adjust_moving_image_pixel_size_y.setDecimals(3) + self.adjust_moving_image_pixel_size_y.setRange(0.001, 100.00) self.scale_moving_image_button = QPushButton() self.scale_moving_image_button.setText("Scale Image") self.scale_moving_image_button.clicked.connect( self._on_scale_image_button_click ) + self.adjust_atlas_pitch = QDoubleSpinBox(parent=self) + self.adjust_atlas_pitch.setSingleStep(0.1) + self.adjust_atlas_pitch.setRange(-rotation_range, rotation_range) + + self.adjust_atlas_yaw = QDoubleSpinBox(parent=self) + self.adjust_atlas_yaw.setSingleStep(0.1) + self.adjust_atlas_yaw.setRange(-rotation_range, rotation_range) + + self.adjust_atlas_roll = QDoubleSpinBox(parent=self) + self.adjust_atlas_roll.setSingleStep(0.1) + self.adjust_atlas_roll.setRange(-rotation_range, rotation_range) + + self.adjust_atlas_rotation = QPushButton() + self.adjust_atlas_rotation.setText("Rotate Atlas") + self.adjust_atlas_rotation.clicked.connect( + self._on_adjust_atlas_rotation + ) + self.reset_atlas_button = QPushButton() + self.reset_atlas_button.setText("Reset Atlas") + self.reset_atlas_button.clicked.connect(self._on_atlas_reset) + self.adjust_moving_image_x = QSpinBox(parent=self) self.adjust_moving_image_x.setRange(-offset_range, offset_range) self.adjust_moving_image_x.valueChanged.connect(self._on_adjust_image) @@ -95,17 +123,24 @@ def __init__(self, parent=None): self.layout().addRow(QLabel("Adjust the moving image scale:")) self.layout().addRow( "Sample image X pixel size (\u03BCm / pixel):", - self.adjust_moving_image_voxel_size_x, + self.adjust_moving_image_pixel_size_x, ) self.layout().addRow( "Sample image Y pixel size (\u03BCm / pixel):", - self.adjust_moving_image_voxel_size_y, + self.adjust_moving_image_pixel_size_y, ) self.layout().addRow(self.scale_moving_image_button) - self.layout().addRow(QLabel("Adjust the sample image position:")) - self.layout().addRow("X offset (pixels):", self.adjust_moving_image_x) - self.layout().addRow("Y offset (pixels):", self.adjust_moving_image_y) + self.layout().addRow(QLabel("Adjust the atlas pitch and yaw: ")) + self.layout().addRow("Pitch:", self.adjust_atlas_pitch) + self.layout().addRow("Yaw:", self.adjust_atlas_yaw) + self.layout().addRow("Roll:", self.adjust_atlas_roll) + self.layout().addRow(self.adjust_atlas_rotation) + self.layout().addRow(self.reset_atlas_button) + + self.layout().addRow(QLabel("Adjust the moving image position: ")) + self.layout().addRow("X offset:", self.adjust_moving_image_x) + self.layout().addRow("Y offset:", self.adjust_moving_image_y) self.layout().addRow( "Rotation (degrees):", self.adjust_moving_image_rotate ) @@ -138,6 +173,26 @@ def _on_scale_image_button_click(self): Emit the scale_image_signal with the entered pixel sizes. """ self.scale_image_signal.emit( - self.adjust_moving_image_voxel_size_x.value(), - self.adjust_moving_image_voxel_size_y.value(), + self.adjust_moving_image_pixel_size_x.value(), + self.adjust_moving_image_pixel_size_y.value(), + ) + + def _on_adjust_atlas_rotation(self): + """ + Emit the atlas_rotation_signal with the entered pitch and yaw. + """ + self.atlas_rotation_signal.emit( + self.adjust_atlas_pitch.value(), + self.adjust_atlas_yaw.value(), + self.adjust_atlas_roll.value(), ) + + def _on_atlas_reset(self): + """ + Reset the pitch, yaw, and roll to 0 and emit the atlas_reset_signal. + """ + self.adjust_atlas_yaw.setValue(0) + self.adjust_atlas_pitch.setValue(0) + self.adjust_atlas_roll.setValue(0) + + self.reset_atlas_signal.emit() diff --git a/pyproject.toml b/pyproject.toml index 63db406..83e575c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,12 +24,15 @@ classifiers = [ dependencies = [ "napari>=0.4.18", - "bg-atlasapi", - "brainglobe-utils>=0.3.4", + "brainglobe-atlasapi", + "brainglobe-utils>=0.4.3", + "dask", + "dask-image", + "itk-elastix", + "lxml_html_clean", "numpy", + "pytransform3d", "qtpy", - "itk-elastix", - "pytransform3d" ] [project.urls] diff --git a/tests/conftest.py b/tests/conftest.py index a7708ef..df32c05 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,8 +3,8 @@ import numpy as np import pytest -from bg_atlasapi import BrainGlobeAtlas -from bg_atlasapi import config as bg_config +from brainglobe_atlasapi import BrainGlobeAtlas +from brainglobe_atlasapi import config as bg_config from PIL import Image diff --git a/tests/test_adjust_moving_image_view.py b/tests/test_adjust_moving_image_view.py index fb2dc41..8eae56c 100644 --- a/tests/test_adjust_moving_image_view.py +++ b/tests/test_adjust_moving_image_view.py @@ -17,7 +17,7 @@ def adjust_moving_image_view() -> AdjustMovingImageView: def test_init(qtbot, adjust_moving_image_view): qtbot.addWidget(adjust_moving_image_view) - assert adjust_moving_image_view.layout().rowCount() == 9 + assert adjust_moving_image_view.layout().rowCount() == 15 @pytest.mark.parametrize( @@ -104,7 +104,7 @@ def test_reset_image_button_click(qtbot, adjust_moving_image_view): @pytest.mark.parametrize( "x_scale, y_scale", - [(2.5, 2.5), (10, 20), (10.221, 10.228)], + [(2.5, 2.5), (10, 20), (10.2212, 10.2289)], ) def test_scale_image_button_click( qtbot, adjust_moving_image_view, x_scale, y_scale @@ -114,12 +114,48 @@ def test_scale_image_button_click( with qtbot.waitSignal( adjust_moving_image_view.scale_image_signal, timeout=1000 ) as blocker: - adjust_moving_image_view.adjust_moving_image_voxel_size_x.setValue( + adjust_moving_image_view.adjust_moving_image_pixel_size_x.setValue( x_scale ) - adjust_moving_image_view.adjust_moving_image_voxel_size_y.setValue( + adjust_moving_image_view.adjust_moving_image_pixel_size_y.setValue( y_scale ) adjust_moving_image_view.scale_moving_image_button.click() - assert blocker.args == [round(x_scale, 2), round(y_scale, 2)] + assert blocker.args == [round(x_scale, 3), round(y_scale, 3)] + + +def test_atlas_rotation_changed( + qtbot, + adjust_moving_image_view, +): + qtbot.addWidget(adjust_moving_image_view) + + with qtbot.waitSignal( + adjust_moving_image_view.atlas_rotation_signal, timeout=1000 + ) as blocker: + adjust_moving_image_view.adjust_atlas_pitch.setValue(10) + adjust_moving_image_view.adjust_atlas_yaw.setValue(20) + adjust_moving_image_view.adjust_atlas_roll.setValue(30) + + adjust_moving_image_view.adjust_atlas_rotation.click() + + assert blocker.args == [10, 20, 30] + + +def test_atlas_reset_button_click( + qtbot, + adjust_moving_image_view, +): + qtbot.addWidget(adjust_moving_image_view) + + with qtbot.waitSignal( + adjust_moving_image_view.reset_atlas_signal, timeout=1000 + ): + adjust_moving_image_view.reset_atlas_button.click() + + assert ( + adjust_moving_image_view.adjust_atlas_pitch.value() == 0 + and adjust_moving_image_view.adjust_atlas_yaw.value() == 0 + and adjust_moving_image_view.adjust_atlas_roll.value() == 0 + ) diff --git a/tests/test_registration_widget.py b/tests/test_registration_widget.py index 2b74c98..828c4aa 100644 --- a/tests/test_registration_widget.py +++ b/tests/test_registration_widget.py @@ -12,6 +12,29 @@ def registration_widget(make_napari_viewer_with_images): return widget +@pytest.fixture() +def registration_widget_with_example_atlas(make_napari_viewer_with_images): + """ + Create an initialised RegistrationWidget with the "example_mouse_100um" + loaded. + + Parameters + ------------ + make_napari_viewer_with_images for testing + Fixture that creates a napari viewer + """ + viewer = make_napari_viewer_with_images + + widget = RegistrationWidget(viewer) + + # Based on the downloaded atlases by the fixture in conftest.py, + # the example atlas will be in the third position. + example_atlas_index = 2 + widget._on_atlas_dropdown_index_changed(example_atlas_index) + + return widget + + def test_registration_widget(make_napari_viewer_with_images): widget = RegistrationWidget(make_napari_viewer_with_images) @@ -132,3 +155,62 @@ def test_scale_moving_image( current_size[0] * y_scale_factor, current_size[1] * x_scale_factor, ) + + +@pytest.mark.parametrize( + "pitch, yaw, roll, expected_shape", + [ + (0, 0, 0, (132, 80, 114)), + (45, 0, 0, (150, 150, 114)), + (0, 45, 0, (174, 80, 174)), + (0, 0, 45, (132, 137, 137)), + ], +) +def test_on_adjust_atlas_rotation( + registration_widget_with_example_atlas, + pitch, + yaw, + roll, + expected_shape, +): + reg_widget = registration_widget_with_example_atlas + atlas_shape = reg_widget._atlas.reference.shape + + reg_widget._on_adjust_atlas_rotation(pitch, yaw, roll) + + assert reg_widget._atlas_data_layer.data.shape == expected_shape + assert reg_widget._atlas_annotations_layer.data.shape == expected_shape + assert reg_widget._atlas.reference.shape == atlas_shape + + +def test_on_adjust_atlas_rotation_no_atlas(registration_widget, mocker): + mocked_show_error = mocker.patch( + "brainglobe_registration.registration_widget.show_error" + ) + registration_widget._on_adjust_atlas_rotation(10, 10, 10) + mocked_show_error.assert_called_once_with( + "No atlas selected. Please select an atlas before rotating" + ) + + +def test_on_atlas_reset(registration_widget_with_example_atlas): + reg_widget = registration_widget_with_example_atlas + atlas_shape = reg_widget._atlas.reference.shape + reg_widget._on_adjust_atlas_rotation(10, 10, 10) + + reg_widget._on_atlas_reset() + + assert reg_widget._atlas_data_layer.data.shape == atlas_shape + assert reg_widget._atlas.reference.shape == atlas_shape + assert reg_widget._atlas_annotations_layer.data.shape == atlas_shape + + +def test_on_atlas_reset_no_atlas(registration_widget, mocker): + mocked_show_error = mocker.patch( + "brainglobe_registration.registration_widget.show_error" + ) + + registration_widget._on_atlas_reset() + mocked_show_error.assert_called_once_with( + "No atlas selected. Please select an atlas before resetting" + ) diff --git a/tests/test_utils.py b/tests/test_utils.py index d857e09..89068e7 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -7,6 +7,7 @@ from brainglobe_registration.utils.utils import ( adjust_napari_image_layer, + calculate_rotated_bounding_box, find_layer_index, get_image_layer_names, open_parameter_file, @@ -93,3 +94,32 @@ def test_get_image_layer_names(make_napari_viewer_with_images): assert len(layer_names) == 2 assert layer_names == ["moving_image", "atlas_image"] + + +@pytest.mark.parametrize( + "basis, rotation, expected_bounds", + [ + (0, 0, (50, 100, 200)), + (0, 90, (50, 200, 100)), + (0, 180, (50, 100, 200)), + (0, 45, (50, 212, 212)), + (1, 0, (50, 100, 200)), + (1, 90, (200, 100, 50)), + (1, 180, (50, 100, 200)), + (1, 45, (177, 100, 177)), + (2, 0, (50, 100, 200)), + (2, 90, (100, 50, 200)), + (2, 180, (50, 100, 200)), + (2, 45, (106, 106, 200)), + ], +) +def test_calculate_rotated_bounding_box(basis, rotation, expected_bounds): + image_shape = (50, 100, 200) + rotation_matrix = np.eye(4) + rotation_matrix[:-1, :-1] = active_matrix_from_angle( + basis, np.deg2rad(rotation) + ) + + result_shape = calculate_rotated_bounding_box(image_shape, rotation_matrix) + + assert result_shape == expected_bounds