From e8040bed2769c8c8dd8dcd9afe9656964c7b86ce Mon Sep 17 00:00:00 2001 From: ashmeigh Date: Mon, 10 Jun 2024 16:28:46 +0100 Subject: [PATCH] flat field operation change to compute function and added error handling --- .../operations/flat_fielding/flat_fielding.py | 93 +++++++------------ 1 file changed, 31 insertions(+), 62 deletions(-) diff --git a/mantidimaging/core/operations/flat_fielding/flat_fielding.py b/mantidimaging/core/operations/flat_fielding/flat_fielding.py index dc6d4dda5d3..34ebdab15d6 100644 --- a/mantidimaging/core/operations/flat_fielding/flat_fielding.py +++ b/mantidimaging/core/operations/flat_fielding/flat_fielding.py @@ -10,8 +10,7 @@ from mantidimaging import helper as h from mantidimaging.core.operations.base_filter import BaseFilter, FilterGroup -from mantidimaging.core.parallel import utility as pu, shared as ps -from mantidimaging.core.utility.progress_reporting import Progress +from mantidimaging.core.parallel import shared as ps from mantidimaging.gui.utility.qt_helpers import Type from mantidimaging.gui.widgets.dataset_selector import DatasetSelectorWidgetView @@ -125,24 +124,33 @@ def filter_func(images: ImageStack, if dark_avg is None: dark_avg = np.zeros_like(flat_avg) - if flat_avg is not None and dark_avg is not None: - if flat_avg.ndim != 2 or dark_avg.ndim != 2: - raise ValueError( - f"Incorrect shape of the flat image ({flat_avg.shape}) or dark image ({dark_avg.shape}) " - f"which should match the shape of the sample images ({images.data.shape[1:]})") - - if not (images.data.shape[1:] == flat_avg.shape == dark_avg.shape): - raise ValueError(f"Not all images are the expected shape: {images.data.shape[1:]}, instead " - f"flat had shape: {flat_avg.shape}, and dark had shape: {dark_avg.shape}") - - progress = Progress.ensure_instance(progress, - num_steps=images.data.shape[0], - task_name='Background Correction') - _execute(images, flat_avg, dark_avg, progress) + params = {'flat_avg': flat_avg, 'dark_avg': dark_avg} + ps.run_compute_func(FlatFieldFilter.compute_function, len(images.data), [images.shared_array], params) h.check_data_stack(images) return images + @staticmethod + def compute_function(index: int, array: np.ndarray, params: dict): + flat_avg = params['flat_avg'] + dark_avg = params['dark_avg'] + + norm_divide = flat_avg - dark_avg + norm_divide[norm_divide == 0] = MINIMUM_PIXEL_VALUE + array[index] -= dark_avg + array[index] /= norm_divide + + def _divide(data, norm_divide): + np.true_divide(data, norm_divide, out=data) + + def _subtract(data, dark=None): + # specify out to do in place, otherwise the data is copied + np.subtract(data, dark, out=data) + + def _norm_divide(flat: np.ndarray, dark: np.ndarray) -> np.ndarray: + # subtract dark from flat + return np.subtract(flat, dark) + @staticmethod def register_gui(form, on_change, view) -> dict[str, Any]: from mantidimaging.gui.utility import add_property_to_form @@ -235,10 +243,13 @@ def register_gui(form, on_change, view) -> dict[str, Any]: } @staticmethod - def execute_wrapper( # type: ignore - flat_before_widget: DatasetSelectorWidgetView, flat_after_widget: DatasetSelectorWidgetView, - dark_before_widget: DatasetSelectorWidgetView, dark_after_widget: DatasetSelectorWidgetView, - selected_flat_fielding_widget: QComboBox, use_dark_widget: QCheckBox) -> partial: + def execute_wrapper(**kwargs): + flat_before_widget = kwargs['flat_before_widget'] + flat_after_widget = kwargs['flat_after_widget'] + dark_before_widget = kwargs['dark_before_widget'] + dark_after_widget = kwargs['dark_after_widget'] + selected_flat_fielding_widget = kwargs['selected_flat_fielding_widget'] + use_dark_widget = kwargs['use_dark_widget'] flat_before_images = BaseFilter.get_images_from_stack(flat_before_widget, "flat before") flat_after_images = BaseFilter.get_images_from_stack(flat_after_widget, "flat after") @@ -247,7 +258,6 @@ def execute_wrapper( # type: ignore dark_after_images = BaseFilter.get_images_from_stack(dark_after_widget, "dark after") selected_flat_fielding = selected_flat_fielding_widget.currentText() - use_dark = use_dark_widget.isChecked() return partial(FlatFieldFilter.filter_func, @@ -276,44 +286,3 @@ def validate_execute_kwargs(kwargs): @staticmethod def group_name() -> FilterGroup: return FilterGroup.Basic - - -def _divide(data, norm_divide): - np.true_divide(data, norm_divide, out=data) - - -def _subtract(data, dark=None): - # specify out to do in place, otherwise the data is copied - np.subtract(data, dark, out=data) - - -def _norm_divide(flat: np.ndarray, dark: np.ndarray) -> np.ndarray: - # subtract dark from flat - return np.subtract(flat, dark) - - -def _execute(images: ImageStack, flat=None, dark=None, progress=None): - with progress: - progress.update(msg="Applying background correction") - - if images.uses_shared_memory: - shared_dark = pu.copy_into_shared_memory(dark) - norm_divide = pu.copy_into_shared_memory(_norm_divide(flat, dark)) - else: - shared_dark = pu.SharedArray(dark, None) - norm_divide = pu.SharedArray(_norm_divide(flat, dark), None) - - # prevent divide-by-zero issues, and negative pixels make no sense - norm_divide.array[norm_divide.array == 0] = MINIMUM_PIXEL_VALUE - - # subtract the dark from all images - do_subtract = ps.create_partial(_subtract, fwd_function=ps.inplace_second_2d) - arrays = [images.shared_array, shared_dark] - ps.execute(do_subtract, arrays, images.data.shape[0], progress) - - # divide the data by (flat - dark) - do_divide = ps.create_partial(_divide, fwd_function=ps.inplace_second_2d) - arrays = [images.shared_array, norm_divide] - ps.execute(do_divide, arrays, images.data.shape[0], progress) - - return images