Skip to content

Commit

Permalink
flat field operation change to compute function and added error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
ashmeigh committed Jun 10, 2024
1 parent 42e4167 commit e8040be
Showing 1 changed file with 31 additions and 62 deletions.
93 changes: 31 additions & 62 deletions mantidimaging/core/operations/flat_fielding/flat_fielding.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@

from mantidimaging import helper as h
from mantidimaging.core.operations.base_filter import BaseFilter, FilterGroup
from mantidimaging.core.parallel import utility as pu, shared as ps
from mantidimaging.core.utility.progress_reporting import Progress
from mantidimaging.core.parallel import shared as ps
from mantidimaging.gui.utility.qt_helpers import Type
from mantidimaging.gui.widgets.dataset_selector import DatasetSelectorWidgetView

Expand Down Expand Up @@ -125,24 +124,33 @@ def filter_func(images: ImageStack,
if dark_avg is None:
dark_avg = np.zeros_like(flat_avg)

if flat_avg is not None and dark_avg is not None:
if flat_avg.ndim != 2 or dark_avg.ndim != 2:
raise ValueError(
f"Incorrect shape of the flat image ({flat_avg.shape}) or dark image ({dark_avg.shape}) "
f"which should match the shape of the sample images ({images.data.shape[1:]})")

if not (images.data.shape[1:] == flat_avg.shape == dark_avg.shape):
raise ValueError(f"Not all images are the expected shape: {images.data.shape[1:]}, instead "
f"flat had shape: {flat_avg.shape}, and dark had shape: {dark_avg.shape}")

progress = Progress.ensure_instance(progress,
num_steps=images.data.shape[0],
task_name='Background Correction')
_execute(images, flat_avg, dark_avg, progress)
params = {'flat_avg': flat_avg, 'dark_avg': dark_avg}
ps.run_compute_func(FlatFieldFilter.compute_function, len(images.data), [images.shared_array], params)

h.check_data_stack(images)
return images

@staticmethod
def compute_function(index: int, array: np.ndarray, params: dict):
flat_avg = params['flat_avg']
dark_avg = params['dark_avg']

norm_divide = flat_avg - dark_avg
norm_divide[norm_divide == 0] = MINIMUM_PIXEL_VALUE
array[index] -= dark_avg
array[index] /= norm_divide

def _divide(data, norm_divide):
np.true_divide(data, norm_divide, out=data)

def _subtract(data, dark=None):
# specify out to do in place, otherwise the data is copied
np.subtract(data, dark, out=data)

def _norm_divide(flat: np.ndarray, dark: np.ndarray) -> np.ndarray:
# subtract dark from flat
return np.subtract(flat, dark)

@staticmethod
def register_gui(form, on_change, view) -> dict[str, Any]:
from mantidimaging.gui.utility import add_property_to_form
Expand Down Expand Up @@ -235,10 +243,13 @@ def register_gui(form, on_change, view) -> dict[str, Any]:
}

@staticmethod
def execute_wrapper( # type: ignore
flat_before_widget: DatasetSelectorWidgetView, flat_after_widget: DatasetSelectorWidgetView,
dark_before_widget: DatasetSelectorWidgetView, dark_after_widget: DatasetSelectorWidgetView,
selected_flat_fielding_widget: QComboBox, use_dark_widget: QCheckBox) -> partial:
def execute_wrapper(**kwargs):
flat_before_widget = kwargs['flat_before_widget']
flat_after_widget = kwargs['flat_after_widget']
dark_before_widget = kwargs['dark_before_widget']
dark_after_widget = kwargs['dark_after_widget']
selected_flat_fielding_widget = kwargs['selected_flat_fielding_widget']
use_dark_widget = kwargs['use_dark_widget']

flat_before_images = BaseFilter.get_images_from_stack(flat_before_widget, "flat before")
flat_after_images = BaseFilter.get_images_from_stack(flat_after_widget, "flat after")
Expand All @@ -247,7 +258,6 @@ def execute_wrapper( # type: ignore
dark_after_images = BaseFilter.get_images_from_stack(dark_after_widget, "dark after")

selected_flat_fielding = selected_flat_fielding_widget.currentText()

use_dark = use_dark_widget.isChecked()

return partial(FlatFieldFilter.filter_func,
Expand Down Expand Up @@ -276,44 +286,3 @@ def validate_execute_kwargs(kwargs):
@staticmethod
def group_name() -> FilterGroup:
return FilterGroup.Basic


def _divide(data, norm_divide):
np.true_divide(data, norm_divide, out=data)


def _subtract(data, dark=None):
# specify out to do in place, otherwise the data is copied
np.subtract(data, dark, out=data)


def _norm_divide(flat: np.ndarray, dark: np.ndarray) -> np.ndarray:
# subtract dark from flat
return np.subtract(flat, dark)


def _execute(images: ImageStack, flat=None, dark=None, progress=None):
with progress:
progress.update(msg="Applying background correction")

if images.uses_shared_memory:
shared_dark = pu.copy_into_shared_memory(dark)
norm_divide = pu.copy_into_shared_memory(_norm_divide(flat, dark))
else:
shared_dark = pu.SharedArray(dark, None)
norm_divide = pu.SharedArray(_norm_divide(flat, dark), None)

# prevent divide-by-zero issues, and negative pixels make no sense
norm_divide.array[norm_divide.array == 0] = MINIMUM_PIXEL_VALUE

# subtract the dark from all images
do_subtract = ps.create_partial(_subtract, fwd_function=ps.inplace_second_2d)
arrays = [images.shared_array, shared_dark]
ps.execute(do_subtract, arrays, images.data.shape[0], progress)

# divide the data by (flat - dark)
do_divide = ps.create_partial(_divide, fwd_function=ps.inplace_second_2d)
arrays = [images.shared_array, norm_divide]
ps.execute(do_divide, arrays, images.data.shape[0], progress)

return images

0 comments on commit e8040be

Please sign in to comment.