diff --git a/satpy/enhancements/__init__.py b/satpy/enhancements/__init__.py index a44ca590cf..86efe1ffba 100644 --- a/satpy/enhancements/__init__.py +++ b/satpy/enhancements/__init__.py @@ -219,11 +219,12 @@ def cira_stretch(img, **kwargs): @exclude_alpha def _cira_stretch(band_data): - log_root = np.log10(0.0223) + dtype = band_data.dtype + log_root = np.log10(0.0223, dtype=dtype) denom = (1.0 - log_root) * 0.75 band_data *= 0.01 band_data = band_data.clip(np.finfo(float).eps) - band_data = np.log10(band_data) + band_data = np.log10(band_data, dtype=dtype) band_data -= log_root band_data /= denom return band_data diff --git a/satpy/tests/enhancement_tests/test_enhancements.py b/satpy/tests/enhancement_tests/test_enhancements.py index 96176fda34..6b797f0015 100644 --- a/satpy/tests/enhancement_tests/test_enhancements.py +++ b/satpy/tests/enhancement_tests/test_enhancements.py @@ -34,30 +34,55 @@ def run_and_check_enhancement(func, data, expected, **kwargs): """Perform basic checks that apply to multiple tests.""" + pre_attrs = data.attrs + img = _get_enhanced_image(func, data, **kwargs) + + _assert_image(img, pre_attrs, func.__name__, "palettes" in kwargs) + _assert_image_data(img, expected) + + +def _get_enhanced_image(func, data, **kwargs): from trollimage.xrimage import XRImage - pre_attrs = data.attrs img = XRImage(data) func(img, **kwargs) + return img + + +def _assert_image(img, pre_attrs, func_name, has_palette): + assert isinstance(img.data, xr.DataArray) assert isinstance(img.data.data, da.Array) + old_keys = set(pre_attrs.keys()) # It is OK to have "enhancement_history" added new_keys = set(img.data.attrs.keys()) - {"enhancement_history"} # In case of palettes are used, _FillValue is added. # Colorize doesn't add the fill value, so ignore that - if "palettes" in kwargs and func.__name__ != "colorize": + if has_palette and func_name != "colorize": assert "_FillValue" in new_keys # Remove it from further comparisons new_keys = new_keys - {"_FillValue"} assert old_keys == new_keys - res_data_arr = img.data - assert isinstance(res_data_arr, xr.DataArray) - assert isinstance(res_data_arr.data, da.Array) - res_data = res_data_arr.data.compute() # mimics what xrimage geotiff writing does + +def _assert_image_data(img, expected, dtype=None): + # Compute the data to mimic what xrimage geotiff writing does + res_data = img.data.data.compute() assert not isinstance(res_data, da.Array) np.testing.assert_allclose(res_data, expected, atol=1.e-6, rtol=0) + if dtype: + assert img.data.dtype == dtype + assert res_data.dtype == dtype + + +def run_and_check_enhancement_with_dtype(func, data, expected, **kwargs): + """Perform basic checks that apply to multiple tests.""" + pre_attrs = data.attrs + img = _get_enhanced_image(func, data, **kwargs) + + _assert_image(img, pre_attrs, func.__name__, "palettes" in kwargs) + _assert_image_data(img, expected, dtype=data.dtype) def identical_decorator(func): @@ -109,14 +134,15 @@ def _calc_func(data): exp_data = exp_data[np.newaxis, :, :] run_and_check_enhancement(_enh_func, in_data, exp_data) - def test_cira_stretch(self): + @pytest.mark.parametrize("dtype", [np.float32, np.float64]) + def test_cira_stretch(self, dtype): """Test applying the cira_stretch.""" from satpy.enhancements import cira_stretch expected = np.array([[ [np.nan, -7.04045974, -7.04045974, 0.79630132, 0.95947296], - [1.05181359, 1.11651012, 1.16635571, 1.20691137, 1.24110186]]]) - run_and_check_enhancement(cira_stretch, self.ch1, expected) + [1.05181359, 1.11651012, 1.16635571, 1.20691137, 1.24110186]]], dtype=dtype) + run_and_check_enhancement_with_dtype(cira_stretch, self.ch1.astype(dtype), expected) def test_reinhard(self): """Test the reinhard algorithm."""