From 7bef48a13fbe18a7f8e3f23c40de532f562c9784 Mon Sep 17 00:00:00 2001 From: Matic Lubej Date: Fri, 25 Aug 2023 14:14:15 +0200 Subject: [PATCH] Split cloud mask task (#722) * add simplified cloud mask task, rename legacy task * update tests for legacy cloud mask, add test for simple cloud mask task * remove unnecessary method * take review comments into account * revert band indices parsing --- eolearn/mask/cloud_mask.py | 128 ++++++++++++++++++++++++++++++++++ tests/mask/test_cloud_mask.py | 26 +++++-- 2 files changed, 149 insertions(+), 5 deletions(-) diff --git a/eolearn/mask/cloud_mask.py b/eolearn/mask/cloud_mask.py index 0d7bbb6c..bb5c3f41 100644 --- a/eolearn/mask/cloud_mask.py +++ b/eolearn/mask/cloud_mask.py @@ -41,6 +41,134 @@ def predict_proba(self, X: np.ndarray) -> np.ndarray: # noqa: N803 class CloudMaskTask(EOTask): + """Cloud masking with the s2cloudless model. Outputs a cloud mask and optionally the cloud probabilities.""" + + MODELS_FOLDER = os.path.join(os.path.dirname(__file__), "models") + CLASSIFIER_NAME = "pixel_s2_cloud_detector_lightGBM_v0.2.txt" + + def __init__( + self, + data_feature: tuple[FeatureType, str], + valid_data_feature: tuple[FeatureType, str], + output_mask_feature: tuple[FeatureType, str], + output_proba_feature: tuple[FeatureType, str] | None = None, + all_bands: bool = True, + threshold: float = 0.4, + average_over: int | None = 4, + dilation_size: int | None = 2, + ): + """ + :param data_feature: A data feature which stores raw Sentinel-2 reflectance bands. + :param valid_data_feature: A mask feature which indicates whether data is valid. + :param output_mask_feature: The output feature containing cloud masks. + :param output_proba_feature: The output feature containing cloud probabilities. By default this is not saved. + :param all_bands: Flag which indicates whether images will consist of all 13 Sentinel-2 L1C bands or only + the required 10. + :param threshold: Cloud probability threshold for the classifier. + :param average_over: Size of the pixel neighbourhood used in the averaging post-processing step. Set to `None` + to skip this post-processing step. + :param dilation_size: Size of the dilation post-processing step. Set to `None` to skip this post-processing + step. + """ + self.data_feature = self.parse_feature(data_feature) + self.data_indices = (0, 1, 3, 4, 7, 8, 9, 10, 11, 12) if all_bands else tuple(range(10)) + self.valid_data_feature = self.parse_feature(valid_data_feature) + + self.output_mask_feature = self.parse_feature(output_mask_feature) + self.output_proba_feature = None + if output_proba_feature is not None: + self.output_proba_feature = self.parse_feature(output_proba_feature) + + self.threshold = threshold + + self.avg_kernel = None + if average_over is not None and average_over > 0: + self.avg_kernel = disk(average_over) / np.sum(disk(average_over)) + + self.dil_kernel = None + if dilation_size is not None and dilation_size > 0: + self.dil_kernel = disk(dilation_size).astype(np.uint8) + + self._classifier: ClassifierType | Booster | None = None + + @property + def classifier(self) -> ClassifierType | Booster: + """An instance of a custom-provided cloud classifier. Loaded only the first time it is required.""" + if self._classifier is None: + path = os.path.join(self.MODELS_FOLDER, self.CLASSIFIER_NAME) + self._classifier = Booster(model_file=path) + + return self._classifier + + @staticmethod + def _run_prediction(classifier: ClassifierType | Booster, features: np.ndarray) -> np.ndarray: + """Uses classifier object on given data""" + is_booster = isinstance(classifier, Booster) + + predict_method = classifier.predict if is_booster else classifier.predict_proba + prediction: np.ndarray = execute_with_mp_lock(predict_method, features) + + return prediction if is_booster else prediction[..., 1] + + def _average(self, data: np.ndarray) -> np.ndarray: + return cv2.filter2D(data.astype(np.float64), -1, self.avg_kernel, borderType=cv2.BORDER_REFLECT) + + def _dilate(self, data: np.ndarray) -> np.ndarray: + return (cv2.dilate(data.astype(np.uint8), self.dil_kernel) > 0).astype(np.uint8) + + def _average_all(self, data: np.ndarray) -> np.ndarray: + """Average over each spatial slice of data""" + if self.avg_kernel is not None: + return _apply_to_spatial_axes(self._average, data, (1, 2)) + + return data + + def _dilate_all(self, data: np.ndarray) -> np.ndarray: + """Dilate over each spatial slice of data""" + if self.dil_kernel is not None: + return _apply_to_spatial_axes(self._dilate, data, (1, 2)) + + return data + + def _do_single_temporal_cloud_detection(self, bands: np.ndarray) -> np.ndarray: + """Performs a cloud detection process on each scene separately""" + output_proba = [] + _, height, width, n_bands = bands.shape + + for img in bands: + features = img.reshape(height * width, n_bands) + proba = self._run_prediction(self.classifier, features) + output_proba.append(proba.reshape(height, width, 1)) + + return np.array(output_proba) + + def execute(self, eopatch: EOPatch) -> EOPatch: + """Add selected features (cloud probabilities and masks) to an EOPatch instance. + + :param eopatch: Input `EOPatch` instance + :return: `EOPatch` with additional features + """ + data = eopatch[self.data_feature][..., self.data_indices].astype(np.float32) + valid_data = eopatch[self.valid_data_feature].astype(bool) + + patch_bbox = eopatch.bbox + if patch_bbox is None: + raise ValueError("Cannot run cloud masking on an EOPatch without a BBox.") + + cloud_proba = self._do_single_temporal_cloud_detection(data) + + # Average over and threshold + cloud_mask = self._average_all(cloud_proba) >= self.threshold + cloud_mask = self._dilate_all(cloud_mask) + eopatch[self.output_mask_feature] = (cloud_mask * valid_data).astype(bool) + + if self.output_proba_feature is not None: + eopatch[self.output_proba_feature] = (cloud_proba * valid_data).astype(np.float32) + + return eopatch + + +class _OldCloudMaskTask(EOTask): """Cloud masking with an improved s2cloudless model and the SSIM-based multi-temporal classifier. Its intended output is a cloud mask that is based on the outputs of both diff --git a/tests/mask/test_cloud_mask.py b/tests/mask/test_cloud_mask.py index 34f8707a..74c1f163 100644 --- a/tests/mask/test_cloud_mask.py +++ b/tests/mask/test_cloud_mask.py @@ -12,7 +12,7 @@ from eolearn.core import FeatureType from eolearn.mask import CloudMaskTask -from eolearn.mask.cloud_mask import _get_window_indices +from eolearn.mask.cloud_mask import _get_window_indices, _OldCloudMaskTask @pytest.mark.parametrize( @@ -41,8 +41,8 @@ def test_window_indices_function(num_of_elements, middle_idx, window_size, expec assert len(test_list[min_idx:max_idx]) == min(num_of_elements, window_size) -def test_mono_temporal_cloud_detection(test_eopatch): - add_tcm = CloudMaskTask( +def test_legacy_mono_temporal_cloud_detection(test_eopatch): + add_tcm = _OldCloudMaskTask( data_feature=(FeatureType.DATA, "BANDS-S2-L1C"), all_bands=True, is_data_feature=(FeatureType.MASK, "IS_DATA"), @@ -58,8 +58,8 @@ def test_mono_temporal_cloud_detection(test_eopatch): assert_array_equal(eop_clm.data["CLP_TEST"], test_eopatch.data["CLP_S2C"]) -def test_multi_temporal_cloud_detection_downscaled(test_eopatch): - add_tcm = CloudMaskTask( +def test_legacy_multi_temporal_cloud_detection_downscaled(test_eopatch): + add_tcm = _OldCloudMaskTask( data_feature=(FeatureType.DATA, "BANDS-S2-L1C"), processing_resolution=120, mono_features=("CLP_TEST", "CLM_TEST"), @@ -90,3 +90,19 @@ def test_multi_temporal_cloud_detection_downscaled(test_eopatch): assert_array_equal(eop_clm.data["CLP_MULTI_TEST"], test_eopatch.data["CLP_MULTI"]) assert_array_equal(eop_clm.mask["CLM_MULTI_TEST"], test_eopatch.mask["CLM_MULTI"]) assert_array_equal(eop_clm.mask["CLM_INTERSSIM_TEST"], test_eopatch.mask["CLM_INTERSSIM"]) + + +def test_cloud_detection(test_eopatch): + add_tcm = CloudMaskTask( + data_feature=(FeatureType.DATA, "BANDS-S2-L1C"), + valid_data_feature=(FeatureType.MASK, "IS_DATA"), + output_mask_feature=(FeatureType.MASK, "CLM_TEST"), + output_proba_feature=(FeatureType.DATA, "CLP_TEST"), + threshold=0.4, + average_over=4, + dilation_size=2, + ) + eop_clm = add_tcm(test_eopatch) + + assert_array_equal(eop_clm.mask["CLM_TEST"], test_eopatch.mask["CLM_S2C"]) + assert_array_equal(eop_clm.data["CLP_TEST"], test_eopatch.data["CLP_S2C"])