Skip to content

Commit

Permalink
Merge pull request #2954 from pnuu/bugfix-rayleigh-reflectance-dtype
Browse files Browse the repository at this point in the history
Fix Rayleigh correction to use the same datatype as the input data
  • Loading branch information
pnuu authored Oct 25, 2024
2 parents 309c874 + 3f1076a commit 5f4e4c1
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 28 deletions.
6 changes: 4 additions & 2 deletions satpy/modifiers/atmosphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ def __call__(self, projectables, optional_datasets=None, **info):
projectables = projectables + (optional_datasets or [])
if len(projectables) != 6:
vis, red = self.match_data_arrays(projectables)
sata, satz, suna, sunz = get_angles(vis)
# Adjust the angle data precision to match the data
# This does not affect the accuracy visibly
sata, satz, suna, sunz = [d.astype(vis.dtype) for d in get_angles(vis)]
else:
vis, red, sata, satz, suna, sunz = self.match_data_arrays(projectables)
# First make sure the two azimuth angles are in the range 0-360:
Expand All @@ -97,7 +99,7 @@ def __call__(self, projectables, optional_datasets=None, **info):
aerosol_type = self.attrs.get("aerosol_type", "marine_clean_aerosol")
reduce_lim_low = abs(self.attrs.get("reduce_lim_low", 70))
reduce_lim_high = abs(self.attrs.get("reduce_lim_high", 105))
reduce_strength = np.clip(self.attrs.get("reduce_strength", 0), 0, 1)
reduce_strength = np.clip(self.attrs.get("reduce_strength", 0), 0, 1).astype(vis.dtype)

logger.info("Removing Rayleigh scattering with atmosphere '%s' and "
"aerosol type '%s' for '%s'",
Expand Down
32 changes: 23 additions & 9 deletions satpy/tests/test_composites.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,20 +257,28 @@ def test_self_sharpened_no_high_res(self):
with pytest.raises(ValueError, match="SelfSharpenedRGB requires at least one high resolution band, not 'None'"):
comp((self.ds1, self.ds2, self.ds3))

def test_basic_no_high_res(self):
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_basic_no_high_res(self, dtype):
"""Test that three datasets can be passed without optional high res."""
from satpy.composites import RatioSharpenedRGB
comp = RatioSharpenedRGB(name="true_color")
res = comp((self.ds1, self.ds2, self.ds3))
res = comp((self.ds1.astype(dtype), self.ds2.astype(dtype), self.ds3.astype(dtype)))
assert res.shape == (3, 2, 2)
assert res.dtype == dtype
assert res.values.dtype == dtype

def test_basic_no_sharpen(self):
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_basic_no_sharpen(self, dtype):
"""Test that color None does no sharpening."""
from satpy.composites import RatioSharpenedRGB
comp = RatioSharpenedRGB(name="true_color", high_resolution_band=None)
res = comp((self.ds1, self.ds2, self.ds3), optional_datasets=(self.ds4,))
res = comp((self.ds1.astype(dtype), self.ds2.astype(dtype), self.ds3.astype(dtype)),
optional_datasets=(self.ds4.astype(dtype),))
assert res.shape == (3, 2, 2)
assert res.dtype == dtype
assert res.values.dtype == dtype

@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize(
("high_resolution_band", "neutral_resolution_band", "exp_r", "exp_g", "exp_b"),
[
Expand Down Expand Up @@ -300,22 +308,26 @@ def test_basic_no_sharpen(self):
np.array([[1.0, 1.0], [np.nan, 1.0]], dtype=np.float64))
]
)
def test_ratio_sharpening(self, high_resolution_band, neutral_resolution_band, exp_r, exp_g, exp_b):
def test_ratio_sharpening(self, high_resolution_band, neutral_resolution_band, exp_r, exp_g, exp_b, dtype):
"""Test RatioSharpenedRGB by different groups of high_resolution_band and neutral_resolution_band."""
from satpy.composites import RatioSharpenedRGB
comp = RatioSharpenedRGB(name="true_color", high_resolution_band=high_resolution_band,
neutral_resolution_band=neutral_resolution_band)
res = comp((self.ds1, self.ds2, self.ds3), optional_datasets=(self.ds4,))
res = comp((self.ds1.astype(dtype), self.ds2.astype(dtype), self.ds3.astype(dtype)),
optional_datasets=(self.ds4.astype(dtype),))

assert "units" not in res.attrs
assert isinstance(res, xr.DataArray)
assert isinstance(res.data, da.Array)
assert res.dtype == dtype

data = res.values
np.testing.assert_allclose(data[0], exp_r, rtol=1e-5)
np.testing.assert_allclose(data[1], exp_g, rtol=1e-5)
np.testing.assert_allclose(data[2], exp_b, rtol=1e-5)
assert res.dtype == dtype

@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize(
("exp_shape", "exp_r", "exp_g", "exp_b"),
[
Expand All @@ -325,17 +337,19 @@ def test_ratio_sharpening(self, high_resolution_band, neutral_resolution_band, e
np.array([[16 / 3, 16 / 3], [16 / 3, 0]], dtype=np.float64))
]
)
def test_self_sharpened_basic(self, exp_shape, exp_r, exp_g, exp_b):
def test_self_sharpened_basic(self, exp_shape, exp_r, exp_g, exp_b, dtype):
"""Test that three datasets can be passed without optional high res."""
from satpy.composites import SelfSharpenedRGB
comp = SelfSharpenedRGB(name="true_color")
res = comp((self.ds1, self.ds2, self.ds3))
data = res.values
res = comp((self.ds1.astype(dtype), self.ds2.astype(dtype), self.ds3.astype(dtype)))
assert res.dtype == dtype

data = res.values
assert data.shape == exp_shape
np.testing.assert_allclose(data[0], exp_r, rtol=1e-5)
np.testing.assert_allclose(data[1], exp_g, rtol=1e-5)
np.testing.assert_allclose(data[2], exp_b, rtol=1e-5)
assert data.dtype == dtype


class TestDifferenceCompositor(unittest.TestCase):
Expand Down
58 changes: 41 additions & 17 deletions satpy/tests/test_modifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,29 +135,46 @@ def test_basic_default_not_provided(self, sunz_ds1, as_32bit):
assert res.dtype == res_np.dtype
assert "y" not in res.coords
assert "x" not in res.coords
if as_32bit:
assert res.dtype == np.float32

def test_basic_lims_not_provided(self, sunz_ds1):
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_basic_lims_not_provided(self, sunz_ds1, dtype):
"""Test custom limits when SZA isn't provided."""
from satpy.modifiers.geometry import SunZenithCorrector
comp = SunZenithCorrector(name="sza_test", modifiers=tuple(), correction_limit=90)
res = comp((sunz_ds1,), test_attr="test")
np.testing.assert_allclose(res.values, np.array([[66.853262, 68.168939], [66.30742, 67.601493]]))

res = comp((sunz_ds1.astype(dtype),), test_attr="test")
expected = np.array([[66.853262, 68.168939], [66.30742, 67.601493]], dtype=dtype)
values = res.values
np.testing.assert_allclose(values, expected, rtol=1e-5)
assert res.dtype == dtype
assert values.dtype == dtype

@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize("data_arr", [lazy_fixture("sunz_ds1"), lazy_fixture("sunz_ds1_stacked")])
def test_basic_default_provided(self, data_arr, sunz_sza):
def test_basic_default_provided(self, data_arr, sunz_sza, dtype):
"""Test default limits when SZA is provided."""
from satpy.modifiers.geometry import SunZenithCorrector
comp = SunZenithCorrector(name="sza_test", modifiers=tuple())
res = comp((data_arr, sunz_sza), test_attr="test")
np.testing.assert_allclose(res.values, np.array([[22.401667, 22.31777], [22.437503, 22.353533]]))

res = comp((data_arr.astype(dtype), sunz_sza.astype(dtype)), test_attr="test")
expected = np.array([[22.401667, 22.31777], [22.437503, 22.353533]], dtype=dtype)
values = res.values
np.testing.assert_allclose(values, expected)
assert res.dtype == dtype
assert values.dtype == dtype

@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize("data_arr", [lazy_fixture("sunz_ds1"), lazy_fixture("sunz_ds1_stacked")])
def test_basic_lims_provided(self, data_arr, sunz_sza):
def test_basic_lims_provided(self, data_arr, sunz_sza, dtype):
"""Test custom limits when SZA is provided."""
from satpy.modifiers.geometry import SunZenithCorrector
comp = SunZenithCorrector(name="sza_test", modifiers=tuple(), correction_limit=90)
res = comp((data_arr, sunz_sza), test_attr="test")
np.testing.assert_allclose(res.values, np.array([[66.853262, 68.168939], [66.30742, 67.601493]]))
res = comp((data_arr.astype(dtype), sunz_sza.astype(dtype)), test_attr="test")
expected = np.array([[66.853262, 68.168939], [66.30742, 67.601493]], dtype=dtype)
values = res.values
np.testing.assert_allclose(values, expected, rtol=1e-5)
assert res.dtype == dtype
assert values.dtype == dtype

def test_imcompatible_areas(self, sunz_ds2, sunz_sza):
"""Test sunz correction on incompatible areas."""
Expand Down Expand Up @@ -502,6 +519,7 @@ def _create_test_data(self, name, wavelength, resolution):
})
return input_band, red_band, angle1, angle1, angle1, angle1

@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize(
("name", "wavelength", "resolution", "aerosol_type", "reduce_lim_low", "reduce_lim_high", "reduce_strength",
"exp_mean", "exp_unique"),
Expand All @@ -521,7 +539,7 @@ def _create_test_data(self, name, wavelength, resolution):
]
)
def test_rayleigh_corrector(self, name, wavelength, resolution, aerosol_type, reduce_lim_low, reduce_lim_high,
reduce_strength, exp_mean, exp_unique):
reduce_strength, exp_mean, exp_unique, dtype):
"""Test PSPRayleighReflectance with fake data."""
from satpy.modifiers.atmosphere import PSPRayleighReflectance
ray_cor = PSPRayleighReflectance(name=name, atmosphere="us-standard", aerosol_types=aerosol_type,
Expand All @@ -535,42 +553,48 @@ def test_rayleigh_corrector(self, name, wavelength, resolution, aerosol_type, re
assert ray_cor.attrs["reduce_strength"] == reduce_strength

input_band, red_band, *_ = self._create_test_data(name, wavelength, resolution)
res = ray_cor([input_band, red_band])
res = ray_cor([input_band.astype(dtype), red_band.astype(dtype)])

assert isinstance(res, xr.DataArray)
assert isinstance(res.data, da.Array)
assert res.dtype == dtype

data = res.values
unique = np.unique(data[~np.isnan(data)])
np.testing.assert_allclose(np.nanmean(data), exp_mean, rtol=1e-5)
assert data.shape == (3, 5)
np.testing.assert_allclose(unique, exp_unique, rtol=1e-5)
assert data.dtype == dtype

@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize("as_optionals", [False, True])
def test_rayleigh_with_angles(self, as_optionals):
def test_rayleigh_with_angles(self, as_optionals, dtype):
"""Test PSPRayleighReflectance with angles provided."""
from satpy.modifiers.atmosphere import PSPRayleighReflectance
aerosol_type = "rayleigh_only"
ray_cor = PSPRayleighReflectance(name="B01", atmosphere="us-standard", aerosol_types=aerosol_type)
prereqs, opt_prereqs = self._get_angles_prereqs_and_opts(as_optionals)
prereqs, opt_prereqs = self._get_angles_prereqs_and_opts(as_optionals, dtype)
with mock.patch("satpy.modifiers.atmosphere.get_angles") as get_angles:
res = ray_cor(prereqs, opt_prereqs)
get_angles.assert_not_called()

assert isinstance(res, xr.DataArray)
assert isinstance(res.data, da.Array)
assert res.dtype == dtype

data = res.values
unique = np.unique(data[~np.isnan(data)])
np.testing.assert_allclose(unique, np.array([-75.0, -37.71298492, 31.14350754]), rtol=1e-5)
assert data.shape == (3, 5)
assert data.dtype == dtype

def _get_angles_prereqs_and_opts(self, as_optionals):
def _get_angles_prereqs_and_opts(self, as_optionals, dtype):
wavelength = (0.45, 0.47, 0.49)
resolution = 1000
input_band, red_band, *angles = self._create_test_data("B01", wavelength, resolution)
prereqs = [input_band, red_band]
prereqs = [input_band.astype(dtype), red_band.astype(dtype)]
opt_prereqs = []
angles = [a.astype(dtype) for a in angles]
if as_optionals:
opt_prereqs = angles
else:
Expand Down

0 comments on commit 5f4e4c1

Please sign in to comment.