Skip to content

Commit

Permalink
Refactor run_and_check_enhancement
Browse files Browse the repository at this point in the history
  • Loading branch information
lahtinep committed Oct 25, 2024
1 parent 81ccc61 commit 5169923
Showing 1 changed file with 33 additions and 11 deletions.
44 changes: 33 additions & 11 deletions satpy/tests/enhancement_tests/test_enhancements.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,35 +32,57 @@
# - tmp_path


def run_and_check_enhancement(func, data, expected, match_dtype=False, **kwargs):
def run_and_check_enhancement(func, data, expected, **kwargs):
"""Perform basic checks that apply to multiple tests."""
pre_attrs = data.attrs
img = _get_enhanced_image(func, data, **kwargs)

_assert_image(img, pre_attrs, func.__name__, "palettes" in kwargs)
_assert_image_data(img, expected)


def _get_enhanced_image(func, data, **kwargs):
from trollimage.xrimage import XRImage

pre_attrs = data.attrs
img = XRImage(data)
func(img, **kwargs)

return img


def _assert_image(img, pre_attrs, func_name, has_palette):
assert isinstance(img.data, xr.DataArray)
assert isinstance(img.data.data, da.Array)

old_keys = set(pre_attrs.keys())
# It is OK to have "enhancement_history" added
new_keys = set(img.data.attrs.keys()) - {"enhancement_history"}
# In case of palettes are used, _FillValue is added.
# Colorize doesn't add the fill value, so ignore that
if "palettes" in kwargs and func.__name__ != "colorize":
if has_palette and func_name != "colorize":
assert "_FillValue" in new_keys
# Remove it from further comparisons
new_keys = new_keys - {"_FillValue"}
assert old_keys == new_keys

res_data_arr = img.data
assert isinstance(res_data_arr, xr.DataArray)
assert isinstance(res_data_arr.data, da.Array)
res_data = res_data_arr.data.compute() # mimics what xrimage geotiff writing does

def _assert_image_data(img, expected, dtype=None):
# Compute the data to mimic what xrimage geotiff writing does
res_data = img.data.data.compute()
assert not isinstance(res_data, da.Array)
np.testing.assert_allclose(res_data, expected, atol=1.e-6, rtol=0)
if match_dtype:
assert res_data_arr.dtype == data.dtype
assert res_data.dtype == data.dtype
if dtype:
assert img.data.dtype == dtype
assert res_data.dtype == dtype


def run_and_check_enhancement_with_dtype(func, data, expected, **kwargs):
"""Perform basic checks that apply to multiple tests."""
pre_attrs = data.attrs
img = _get_enhanced_image(func, data, **kwargs)

_assert_image(img, pre_attrs, func.__name__, "palettes" in kwargs)
_assert_image_data(img, expected, dtype=data.dtype)


def identical_decorator(func):
Expand Down Expand Up @@ -120,7 +142,7 @@ def test_cira_stretch(self, dtype):
expected = np.array([[
[np.nan, -7.04045974, -7.04045974, 0.79630132, 0.95947296],
[1.05181359, 1.11651012, 1.16635571, 1.20691137, 1.24110186]]], dtype=dtype)
run_and_check_enhancement(cira_stretch, self.ch1.astype(dtype), expected, match_dtype=True)
run_and_check_enhancement_with_dtype(cira_stretch, self.ch1.astype(dtype), expected)

def test_reinhard(self):
"""Test the reinhard algorithm."""
Expand Down

0 comments on commit 5169923

Please sign in to comment.