From c638c1f4430f818442d5d482b1a8088accc19583 Mon Sep 17 00:00:00 2001
From: Stephan Hoyer
Date: Wed, 13 Nov 2024 15:41:19 -0800
Subject: [PATCH] Add spatial threshold-based ensemble metrics

PiperOrigin-RevId: 696294781
---
 scripts/evaluate.py           |  30 +-
 weatherbench2/metrics.py      | 617 ++++++++++++++++++----------
 weatherbench2/metrics_test.py |  28 +-
 3 files changed, 369 insertions(+), 306 deletions(-)

diff --git a/scripts/evaluate.py b/scripts/evaluate.py
index c0e0331..1ec9659 100644
--- a/scripts/evaluate.py
+++ b/scripts/evaluate.py
@@ -512,13 +512,13 @@ def main(argv: list[str]) -> None:
       'ensemble_binary': config.Eval(
           metrics={
               'brier_score': metrics.EnsembleBrierScore(
-                  ensemble_dim=ENSEMBLE_DIM.value, threshold=threshold_list
+                  ensemble_dim=ENSEMBLE_DIM.value, thresholds=threshold_list
               ),
               'debiased_brier_score': metrics.DebiasedEnsembleBrierScore(
-                  ensemble_dim=ENSEMBLE_DIM.value, threshold=threshold_list
+                  ensemble_dim=ENSEMBLE_DIM.value, thresholds=threshold_list
               ),
               'ignorance_score': metrics.EnsembleIgnoranceScore(
-                  ensemble_dim=ENSEMBLE_DIM.value, threshold=threshold_list
+                  ensemble_dim=ENSEMBLE_DIM.value, thresholds=threshold_list
               ),
           },
           regions=regions,
@@ -583,6 +583,26 @@ def main(argv: list[str]) -> None:
           probabilistic_climatology_hour_interval=PROBABILISTIC_CLIMATOLOGY_HOUR_INTERVAL.value,
           output_format='zarr',
       ),
+      'ensemble_binary_spatial': config.Eval(
+          metrics={
+              'brier_score': metrics.SpatialEnsembleBrierScore(
+                  ensemble_dim=ENSEMBLE_DIM.value, thresholds=threshold_list
+              ),
+              'debiased_brier_score': metrics.SpatialDebiasedEnsembleBrierScore(
+                  ensemble_dim=ENSEMBLE_DIM.value, thresholds=threshold_list
+              ),
+              'ignorance_score': metrics.SpatialEnsembleIgnoranceScore(
+                  ensemble_dim=ENSEMBLE_DIM.value, thresholds=threshold_list
+              ),
+          },
+          against_analysis=False,
+          derived_variables=derived_variables,
+          evaluate_probabilistic_climatology=EVALUATE_PROBABILISTIC_CLIMATOLOGY.value,
+          probabilistic_climatology_start_year=PROBABILISTIC_CLIMATOLOGY_START_YEAR.value,
+          probabilistic_climatology_end_year=PROBABILISTIC_CLIMATOLOGY_END_YEAR.value,
+          probabilistic_climatology_hour_interval=PROBABILISTIC_CLIMATOLOGY_HOUR_INTERVAL.value,
+          output_format='zarr',
+      ),
       'probabilistic_spatial_histograms': config.Eval(
           metrics={
               'rank_histogram': metrics.RankHistogram(
@@ -609,10 +629,10 @@ def main(argv: list[str]) -> None:
       'gaussian_binary': config.Eval(
           metrics={
               'brier_score': metrics.GaussianBrierScore(
-                  threshold=threshold_list
+                  thresholds=threshold_list
               ),
               'ignorance_score': metrics.GaussianIgnoranceScore(
-                  threshold=threshold_list
+                  thresholds=threshold_list
              ),
           },
           against_analysis=False,
diff --git a/weatherbench2/metrics.py b/weatherbench2/metrics.py
index a6cb92f..9c29238 100644
--- a/weatherbench2/metrics.py
+++ b/weatherbench2/metrics.py
@@ -938,7 +938,65 @@ def compute_chunk(
 
 
 @dataclasses.dataclass
-class GaussianBrierScore(Metric):
+class ThresholdMetric(Metric):
+  """Base class for metrics based on thresholds."""
+
+  thresholds: Sequence[thresholds.Threshold]
+
+  def _map_over_thresholds(
+      self,
+      calculate_score: t.Callable[
+          [xr.Dataset, xr.Dataset, xr.Dataset], xr.Dataset
+      ],
+      forecast: xr.Dataset,
+      truth: xr.Dataset,
+      region: t.Optional[Region],
+      skipna: bool,
+      spatial_agg: bool,
+  ) -> xr.Dataset:
+    """Map a function over all thresholds."""
+    scores = []
+    for threshold in self.thresholds:
+      threshold_ds = threshold.compute(truth)
+      score = calculate_score(forecast, truth, threshold_ds)
+      if spatial_agg:
+        score = _spatial_average(score, region=region, skipna=skipna)
+      scores.append(score.expand_dims(dim={"quantile": [threshold.quantile]}))
+    threshold_method = type(self.thresholds[0]).__name__
+    return xr.concat(scores, dim="quantile").assign_attrs(
+        threshold_method=threshold_method
+    )
+
+
+def _compute_gaussian_brier_score(
+    forecast: xr.Dataset,
+    truth: xr.Dataset,
+    threshold: xr.Dataset,
+) -> xr.Dataset:
+  """Computes the Brier score for a Gaussian distribution."""
+  truth_probability = xr.where(truth > threshold, 1.0, 0.0)
+
+  var_list = []
+  exceedance_probability = {}
+  for var in forecast.keys():
+    if f"{var}_std" in forecast.keys():
+      var_list.append(var)
+
+  for var_name in var_list:
+    std = forecast[f"{var_name}_std"]
+    norm_threshold = (threshold[var_name] - forecast[var_name]) / std
+    exceedance_probability[var_name] = 1 - xr.apply_ufunc(
+        stats.norm.cdf, norm_threshold.load()
+    )
+
+  forecast_probability = xr.Dataset(
+      exceedance_probability, coords=forecast.coords
+  )
+  return (forecast_probability - truth_probability) ** 2
+
+
+@dataclasses.dataclass
+class GaussianBrierScore(ThresholdMetric):
   """Brier score of a Gaussian forecast for a given binary threshold.
 
   The Brier score is computed based on the forecast probability of exceedance of
@@ -960,8 +1018,6 @@ class GaussianBrierScore(Metric):
     Spatially averaged Brier score for a Gaussian distribution.
   """
 
-  threshold: t.Union[thresholds.Threshold, Sequence[thresholds.Threshold]]
-
   def compute_chunk(
       self,
       forecast: xr.Dataset,
@@ -969,53 +1025,44 @@ def compute_chunk(
       region: t.Optional[Region] = None,
       skipna: bool = False,
   ) -> xr.Dataset:
+    return self._map_over_thresholds(
+        _compute_gaussian_brier_score,
+        forecast,
+        truth,
+        region=region,
+        skipna=skipna,
+        spatial_agg=True,
+    )
 
-    if isinstance(self.threshold, thresholds.Threshold):
-      threshold_seq = [self.threshold]
-      threshold_method = type(self.threshold).__name__
-    else:
-      threshold_seq = self.threshold
-      threshold_method = type(self.threshold[0]).__name__
-
-    brier_scores = []
-    for threshold in threshold_seq:
-      quantile = threshold.quantile
-      threshold = threshold.compute(truth)
-      truth_probability = xr.where(truth > threshold, 1.0, 0.0)
-
-      var_list = []
-      exceedance_probability = {}
-      for var in forecast.keys():
-        if f"{var}_std" in forecast.keys():
-          var_list.append(var)
-
-      for var_name in var_list:
-        norm_threshold = (threshold[var_name] - forecast[var_name]) / forecast[
-            f"{var_name}_std"
-        ]
-        exceedance_probability[var_name] = 1 - xr.apply_ufunc(
-            stats.norm.cdf, norm_threshold.load()
-        )
 
-      forecast_probability = xr.Dataset(
-          exceedance_probability, coords=forecast.coords
-      )
+def _compute_gaussian_ignorance_score(
+    forecast: xr.Dataset,
+    truth: xr.Dataset,
+    threshold: xr.Dataset,
+) -> xr.Dataset:
+  """Computes the Ignorance score for a Gaussian distribution."""
+  truth_probability = xr.where(truth > threshold, 1.0, 0.0)
 
-      brier_scores.append(
-          _spatial_average(
-              (forecast_probability - truth_probability) ** 2,
-              region=region,
-              skipna=skipna,
-          ).expand_dims(dim={"quantile": [quantile]})
-      )
+  log_realized_probability = {}
+  var_list = [var for var in forecast.keys() if f"{var}_std" in forecast.keys()]
 
-    return xr.merge(brier_scores).assign_attrs(
-        threshold_method=threshold_method
+  for var_name in var_list:
+    norm_threshold = (threshold[var_name] - forecast[var_name]) / forecast[
+        f"{var_name}_std"
+    ]
+    cdf_value = xr.apply_ufunc(stats.norm.cdf, norm_threshold.load())
+    log_realized_probability[var_name] = -xr.where(
+        truth_probability[var_name],
+        xr.apply_ufunc(np.log, 1 - cdf_value),
+        xr.apply_ufunc(np.log, cdf_value),
     )
+  ignorance_score = xr.Dataset(log_realized_probability, coords=forecast.coords)
+  return ignorance_score
+
 
 @dataclasses.dataclass
-class GaussianIgnoranceScore(Metric):
+class GaussianIgnoranceScore(ThresholdMetric):
   """Ignorance score of a Gaussian forecast for a given binary threshold.
 
   The ignorance or logarithmic score is computed based on the forecast
@@ -1034,8 +1081,6 @@ class GaussianIgnoranceScore(Metric):
     Spatially averaged ignorance score for a Gaussian distribution.
   """
 
-  threshold: t.Union[thresholds.Threshold, Sequence[thresholds.Threshold]]
-
   def compute_chunk(
       self,
       forecast: xr.Dataset,
@@ -1044,53 +1089,38 @@ def compute_chunk(
       skipna: bool = False,
   ) -> xr.Dataset:
 
-    if isinstance(self.threshold, thresholds.Threshold):
-      threshold_seq = [self.threshold]
-      threshold_method = type(self.threshold).__name__
-    else:
-      threshold_seq = self.threshold
-      threshold_method = type(self.threshold[0]).__name__
-
-    ignorance_scores = []
-    for threshold in threshold_seq:
-      quantile = threshold.quantile
-      threshold = threshold.compute(truth)
-      truth_probability = xr.where(truth > threshold, 1.0, 0.0)
-
-      var_list = []
-      log_realized_probability = {}
-      for var in forecast.keys():
-        if f"{var}_std" in forecast.keys():
-          var_list.append(var)
-
-      for var_name in var_list:
-        norm_threshold = (threshold[var_name] - forecast[var_name]) / forecast[
-            f"{var_name}_std"
-        ]
-        cdf_value = xr.apply_ufunc(stats.norm.cdf, norm_threshold.load())
-        log_realized_probability[var_name] = -xr.where(
-            truth_probability[var_name],
-            xr.apply_ufunc(np.log, 1 - cdf_value),
-            xr.apply_ufunc(np.log, cdf_value),
-        )
+    return self._map_over_thresholds(
+        _compute_gaussian_ignorance_score,
+        forecast,
+        truth,
+        region=region,
+        skipna=skipna,
+        spatial_agg=True,
+    )
 
-      ignorance_score = xr.Dataset(
-          log_realized_probability, coords=forecast.coords
-      )
-      ignorance_scores.append(
-          _spatial_average(
-              ignorance_score, region=region, skipna=skipna
-          ).expand_dims(dim={"quantile": [quantile]})
-      )
 
-    return xr.merge(ignorance_scores).assign_attrs(
-        threshold_method=threshold_method
-    )
+def _compute_gaussian_rps_part(
+    forecast: xr.Dataset,
+    truth: xr.Dataset,
+    threshold: xr.Dataset,
+) -> xr.Dataset:
+  """Computes the Ranked Probability Score for a Gaussian distribution."""
+  truth_ecdf = xr.where(truth < threshold, 1.0, 0.0)
+
+  var_list = [var for var in forecast.keys() if f"{var}_std" in forecast.keys()]
+
+  cdf_values = {}
+  for var_name in var_list:
+    std = forecast[f"{var_name}_std"]
+    norm_threshold = (threshold[var_name] - forecast[var_name]) / std
+    cdf_values[var_name] = xr.apply_ufunc(stats.norm.cdf, norm_threshold.load())
+
+  forecast_cdf = xr.Dataset(cdf_values, coords=forecast.coords)
+  return (forecast_cdf - truth_ecdf) ** 2
 
 
 @dataclasses.dataclass
-class GaussianRPS(Metric):
+class GaussianRPS(ThresholdMetric):
   """Ranked probability score of a Gaussian forecast for a given quantization.
 
   The ranked probability score (RPS) is computed based on the forecast and
@@ -1110,8 +1140,6 @@ class GaussianRPS(Metric):
     Spatially averaged RPS for a Gaussian distribution.
   """
 
-  thresholds: Sequence[thresholds.Threshold]
-
   def compute_chunk(
       self,
       forecast: xr.Dataset,
@@ -1120,31 +1148,14 @@ def compute_chunk(
       skipna: bool = False,
   ) -> xr.Dataset:
 
-    var_list = []
-    for var in forecast.keys():
-      if f"{var}_std" in forecast.keys():
-        var_list.append(var)
-
-    rps_per_threshold = []
-    threshold_list = [t.compute(truth) for t in self.thresholds]
-    for threshold in threshold_list:
-      truth_ecdf = xr.where(truth < threshold, 1.0, 0.0)
-
-      cdf_values = {}
-      for var_name in var_list:
-        norm_threshold = (threshold[var_name] - forecast[var_name]) / forecast[
-            f"{var_name}_std"
-        ]
-        cdf_values[var_name] = xr.apply_ufunc(
-            stats.norm.cdf, norm_threshold.load()
-        )
-
-      forecast_cdf = xr.Dataset(cdf_values, coords=forecast.coords)
-      rps_per_threshold.append((forecast_cdf - truth_ecdf) ** 2)
-
-    return _spatial_average(
-        sum(rps_per_threshold), region=region, skipna=skipna
-    )
+    return self._map_over_thresholds(
+        _compute_gaussian_rps_part,
+        forecast,
+        truth,
+        region=region,
+        skipna=skipna,
+        spatial_agg=True,
+    ).sum(dim="quantile")
 
 
 @dataclasses.dataclass
@@ -1510,86 +1521,45 @@ def compute_chunk(
 # components, as a sort of probabilistic variant of WindVectorMSE.
 
 
-@dataclasses.dataclass
-class _BaseEnsembleBrierScore(EnsembleMetric):
-  """Base class for [Debiased]EnsembleBrierScore."""
-
-  def __init__(
-      self,
-      threshold: t.Union[thresholds.Threshold, Sequence[thresholds.Threshold]],
-      ensemble_dim: str = REALIZATION,
-  ):
-    """Initializes a _BaseEnsembleBrierScore.
-
-    Args:
-      threshold: Threshold used to binarize predictions and targets.
-      ensemble_dim: Dimension indexing ensemble member.
-    """
-    super().__init__(ensemble_dim=ensemble_dim)
-    self.threshold = threshold
-
-  def _compute_chunk_impl(
-      self,
-      debias: bool,
-      forecast: xr.Dataset,
-      truth: xr.Dataset,
-      region: t.Optional[Region],
-      skipna: bool,
-  ) -> xr.Dataset:
-    """Common implementation of compute_chunk."""
-
-    if isinstance(self.threshold, thresholds.Threshold):
-      threshold_seq = [self.threshold]
-      threshold_method = type(self.threshold).__name__
-    else:
-      threshold_seq = self.threshold
-      threshold_method = type(self.threshold[0]).__name__
-
-    brier_scores = []
-    for threshold in threshold_seq:
-      quantile = threshold.quantile
-      threshold = threshold.compute(truth)
-      # Notice we allow NaN in truth/forecast probabilities, then skipna during
-      # computation of BrierScore (which is really just an MSE over the
-      # probabilities).
-      truth_probability = xr.where(
-          truth.isnull(),
-          np.nan,
-          xr.where(truth > threshold, 1.0, 0.0),
-      )
-      forecast_probability = xr.where(
-          forecast.isnull(),
-          np.nan,
-          xr.where(forecast > threshold, 1.0, 0.0),
-      )
-      if debias:
-        mse_of_probabilities = _debiased_ensemble_mean_mse(
-            forecast_probability,
-            truth_probability,
-            self.ensemble_dim,
-            skipna=skipna,
-        )
-      else:
-        mse_of_probabilities = (
-            forecast_probability.mean(self.ensemble_dim, skipna=skipna)
-            - truth_probability
-        ) ** 2
-
-      brier_scores.append(
-          _spatial_average(
-              mse_of_probabilities,
-              region=region,
-              skipna=skipna,
-          ).expand_dims(dim={"quantile": [quantile]})
-      )
-
-    return xr.merge(brier_scores).assign_attrs(
-        threshold_method=threshold_method
+def _compute_brier_score(
+    forecast: xr.Dataset,
+    truth: xr.Dataset,
+    threshold: xr.Dataset,
+    ensemble_dim: str,
+    debias: bool,
+    skipna: bool,
+) -> xr.Dataset:
+  """Compute the Brier score for a single threshold."""
+  # Notice we allow NaN in truth/forecast probabilities, then skipna during
+  # computation of BrierScore (which is really just an MSE over the
+  # probabilities).
+  truth_probability = xr.where(
+      truth.isnull(),
+      np.nan,
+      xr.where(truth > threshold, 1.0, 0.0),
+  )
+  forecast_probability = xr.where(
+      forecast.isnull(),
+      np.nan,
+      xr.where(forecast > threshold, 1.0, 0.0),
+  )
+  if debias:
+    mse_of_probabilities = _debiased_ensemble_mean_mse(
+        forecast_probability,
+        truth_probability,
+        ensemble_dim,
+        skipna=skipna,
     )
+  else:
+    mse_of_probabilities = (
+        forecast_probability.mean(ensemble_dim, skipna=skipna)
+        - truth_probability
+    ) ** 2
+  return mse_of_probabilities
 
 
 @dataclasses.dataclass
-class EnsembleBrierScore(_BaseEnsembleBrierScore):
+class EnsembleBrierScore(EnsembleMetric, ThresholdMetric):
   """Brier score of an ensemble forecast for a given binary threshold.
 
   The Brier score is computed based on the forecast probability of exceedance of
@@ -1619,18 +1589,31 @@ class EnsembleBrierScore(_BaseEnsembleBrierScore):
     Score, DOI: https://doi.org/10.1175/WAF1034.1
   """
 
-  def __init__(
+  def compute_chunk(
       self,
-      threshold: t.Union[thresholds.Threshold, Sequence[thresholds.Threshold]],
-      ensemble_dim: str = REALIZATION,
-  ):
-    """Initializes an EnsembleBrierScore.
+      forecast: xr.Dataset,
+      truth: xr.Dataset,
+      region: t.Optional[Region] = None,
+      skipna: bool = False,
+  ) -> xr.Dataset:
+    return self._map_over_thresholds(
+        functools.partial(
+            _compute_brier_score,
+            ensemble_dim=self.ensemble_dim,
+            debias=False,
+            skipna=skipna,
+        ),
+        forecast,
+        truth,
+        region=region,
+        skipna=skipna,
+        spatial_agg=True,
+    )
 
-    Args:
-      threshold: Threshold used to binarize predictions and targets.
-      ensemble_dim: Dimension indexing ensemble member.
-    """
-    super().__init__(threshold=threshold, ensemble_dim=ensemble_dim)
+
+@dataclasses.dataclass
+class SpatialEnsembleBrierScore(EnsembleMetric, ThresholdMetric):
+  """Spatial map of ensemble Brier score."""
 
   def compute_chunk(
       self,
@@ -1639,17 +1622,23 @@ def compute_chunk(
       region: t.Optional[Region] = None,
       skipna: bool = False,
   ) -> xr.Dataset:
-    return self._compute_chunk_impl(
-        debias=False,
-        forecast=forecast,
-        truth=truth,
+    return self._map_over_thresholds(
+        functools.partial(
+            _compute_brier_score,
+            ensemble_dim=self.ensemble_dim,
+            debias=False,
+            skipna=skipna,
+        ),
+        forecast,
+        truth,
         region=region,
         skipna=skipna,
+        spatial_agg=False,
     )
 
 
 @dataclasses.dataclass
-class DebiasedEnsembleBrierScore(_BaseEnsembleBrierScore):
+class DebiasedEnsembleBrierScore(EnsembleMetric, ThresholdMetric):
   """Debiased Brier score of an ensemble forecast for a given binary threshold.
 
   The Brier score is computed based on the forecast probability of exceedance of
@@ -1682,18 +1671,31 @@ class DebiasedEnsembleBrierScore(_BaseEnsembleBrierScore):
     Score, DOI: https://doi.org/10.1175/WAF1034.1
   """
 
-  def __init__(
+  def compute_chunk(
       self,
-      threshold: t.Union[thresholds.Threshold, Sequence[thresholds.Threshold]],
-      ensemble_dim: str = REALIZATION,
-  ):
-    """Initializes a DebiasedEnsembleBrierScore.
+      forecast: xr.Dataset,
+      truth: xr.Dataset,
+      region: t.Optional[Region] = None,
+      skipna: bool = False,
+  ) -> xr.Dataset:
+    return self._map_over_thresholds(
+        functools.partial(
+            _compute_brier_score,
+            ensemble_dim=self.ensemble_dim,
+            debias=True,
+            skipna=skipna,
+        ),
+        forecast,
+        truth,
+        region=region,
+        skipna=skipna,
+        spatial_agg=True,
+    )
 
-    Args:
-      threshold: Threshold used to binarize predictions and targets.
-      ensemble_dim: Dimension indexing ensemble member.
-    """
-    super().__init__(threshold=threshold, ensemble_dim=ensemble_dim)
+
+@dataclasses.dataclass
+class SpatialDebiasedEnsembleBrierScore(EnsembleMetric, ThresholdMetric):
+  """Spatial map of ensemble debiased Brier score."""
 
   def compute_chunk(
       self,
@@ -1702,17 +1704,44 @@ def compute_chunk(
       region: t.Optional[Region] = None,
       skipna: bool = False,
   ) -> xr.Dataset:
-    return self._compute_chunk_impl(
-        debias=True,
-        forecast=forecast,
-        truth=truth,
+    return self._map_over_thresholds(
+        functools.partial(
+            _compute_brier_score,
+            ensemble_dim=self.ensemble_dim,
+            debias=True,
+            skipna=skipna,
+        ),
+        forecast,
+        truth,
         region=region,
         skipna=skipna,
+        spatial_agg=False,
     )
 
 
+def _compute_ignorance_score(
+    forecast: xr.Dataset,
+    truth: xr.Dataset,
+    threshold: xr.Dataset,
+    ensemble_dim: str,
+    skipna: bool,
+) -> xr.Dataset:
+  """Compute the Ignorance score for a single threshold."""
+  truth_probability = xr.where(truth > threshold, 1.0, 0.0)
+  forecast_probability = xr.where(forecast > threshold, 1.0, 0.0)
+  ensemble_forecast_probability = forecast_probability.mean(
+      ensemble_dim, skipna=skipna
+  )
+  ignorance_score = -xr.where(
+      truth_probability,
+      xr.apply_ufunc(np.log, ensemble_forecast_probability),
+      xr.apply_ufunc(np.log, 1 - ensemble_forecast_probability),
+  )
+  return ignorance_score
+
+
 @dataclasses.dataclass
-class EnsembleIgnoranceScore(EnsembleMetric):
+class EnsembleIgnoranceScore(EnsembleMetric, ThresholdMetric):
   """Ignorance score of an ensemble forecast for a given binary threshold.
 
   The ignorance or logarithmic score is computed based on the forecast
@@ -1725,19 +1754,31 @@ class EnsembleIgnoranceScore(EnsembleMetric):
     DOI: https://doi.org/10.1175/2009MWR2945.1
   """
 
-  def __init__(
+  def compute_chunk(
       self,
-      threshold: t.Union[thresholds.Threshold, Sequence[thresholds.Threshold]],
-      ensemble_dim: str = REALIZATION,
-  ):
-    """Initializes an EnsembleIgnoranceScore.
+      forecast: xr.Dataset,
+      truth: xr.Dataset,
+      region: t.Optional[Region] = None,
+      skipna: bool = False,
+  ) -> xr.Dataset:
 
-    Args:
-      threshold: Threshold used to binarize predictions and targets.
-      ensemble_dim: Dimension indexing ensemble member.
-    """
-    super().__init__(ensemble_dim=ensemble_dim)
-    self.threshold = threshold
+    return self._map_over_thresholds(
+        functools.partial(
+            _compute_ignorance_score,
+            ensemble_dim=self.ensemble_dim,
+            skipna=skipna,
+        ),
+        forecast,
+        truth,
+        region=region,
+        skipna=skipna,
+        spatial_agg=True,
+    )
+
+
+@dataclasses.dataclass
+class SpatialEnsembleIgnoranceScore(EnsembleMetric, ThresholdMetric):
+  """Spatial map of ensemble ignorance score."""
 
   def compute_chunk(
       self,
@@ -1747,43 +1788,36 @@ def compute_chunk(
       skipna: bool = False,
   ) -> xr.Dataset:
 
-    if isinstance(self.threshold, thresholds.Threshold):
-      threshold_seq = [self.threshold]
-      threshold_method = type(self.threshold).__name__
-    else:
-      threshold_seq = self.threshold
-      threshold_method = type(self.threshold[0]).__name__
-
-    ignorance_scores = []
-    for threshold in threshold_seq:
-      quantile = threshold.quantile
-      threshold = threshold.compute(truth)
-      truth_probability = xr.where(truth > threshold, 1.0, 0.0)
-      forecast_probability = xr.where(forecast > threshold, 1.0, 0.0)
-      ensemble_forecast_probability = forecast_probability.mean(
-          self.ensemble_dim,
-          skipna=skipna,
-      )
-      ignorance_score = -xr.where(
-          truth_probability,
-          xr.apply_ufunc(np.log, ensemble_forecast_probability),
-          xr.apply_ufunc(np.log, 1 - ensemble_forecast_probability),
-      )
-      ignorance_scores.append(
-          _spatial_average(
-              ignorance_score,
-              region=region,
-              skipna=skipna,
-          ).expand_dims(dim={"quantile": [quantile]})
-      )
-
-    return xr.merge(ignorance_scores).assign_attrs(
-        threshold_method=threshold_method
+    return self._map_over_thresholds(
+        functools.partial(
+            _compute_ignorance_score,
+            ensemble_dim=self.ensemble_dim,
+            skipna=skipna,
+        ),
+        forecast,
+        truth,
+        region=region,
+        skipna=skipna,
+        spatial_agg=False,
     )
 
 
+def _compute_rps_part(
+    forecast: xr.Dataset,
+    truth: xr.Dataset,
+    threshold: xr.Dataset,
+    ensemble_dim: str,
+    skipna: bool,
+) -> xr.Dataset:
+  """Compute the contribution to RPS for a single threshold."""
+  truth_ecdf = xr.where(truth < threshold, 1.0, 0.0)
+  forecast_ecdf = xr.where(forecast < threshold, 1.0, 0.0)
+  ensemble_forecast_ecdf = forecast_ecdf.mean(ensemble_dim, skipna=skipna)
+  return (ensemble_forecast_ecdf - truth_ecdf) ** 2
+
+
 @dataclasses.dataclass
-class EnsembleRPS(EnsembleMetric):
+class EnsembleRPS(EnsembleMetric, ThresholdMetric):
   """Ranked probability score of an ensemble forecast for a given quantization.
 
   The ranked probability score (RPS) is computed based on the forecast and
@@ -1811,20 +1845,30 @@ class EnsembleRPS(EnsembleMetric):
     DOI: https://doi.org/10.1175/1520-0450(1969)008<0985:ASSFPF>2.0.CO;2
   """
 
-  def __init__(
+  def compute_chunk(
       self,
-      threshold: Sequence[thresholds.Threshold],
-      ensemble_dim: str = REALIZATION,
-  ):
-    """Initializes an EnsembleRPS.
+      forecast: xr.Dataset,
+      truth: xr.Dataset,
+      region: t.Optional[Region] = None,
+      skipna: bool = False,
+  ) -> xr.Dataset:
+
+    result = self._map_over_thresholds(
+        functools.partial(
+            _compute_rps_part, ensemble_dim=self.ensemble_dim, skipna=skipna
+        ),
+        forecast,
+        truth,
+        region=region,
+        skipna=skipna,
+        spatial_agg=True,
+    )
+    return result.sum("quantile")
 
-    Args:
-      threshold: A sequence of thresholds used to divide predictions and targets
-        categorically.
-      ensemble_dim: Dimension indexing ensemble member.
-    """
-    super().__init__(ensemble_dim=ensemble_dim)
-    self.thresholds = threshold
+
+@dataclasses.dataclass
+class SpatialEnsembleRPS(EnsembleMetric, ThresholdMetric):
+  """Spatial map of ensemble RPS."""
 
   def compute_chunk(
       self,
@@ -1833,23 +1877,18 @@ def compute_chunk(
       region: t.Optional[Region] = None,
       skipna: bool = False,
   ) -> xr.Dataset:
-    """Spatially averaged RPS of the ensemble forecast."""
-    rps_per_threshold = []
-    threshold_list = [t.compute(truth) for t in self.thresholds]
-    for threshold in threshold_list:
-
-      truth_ecdf = xr.where(truth < threshold, 1.0, 0.0)
-      forecast_ecdf = xr.where(forecast < threshold, 1.0, 0.0)
-      ensemble_forecast_ecdf = forecast_ecdf.mean(
-          self.ensemble_dim,
-          skipna=skipna,
-      )
-      rps_per_threshold.append((ensemble_forecast_ecdf - truth_ecdf) ** 2)
-
-    return _spatial_average(
-        sum(rps_per_threshold), region=region, skipna=skipna
+    result = self._map_over_thresholds(
+        functools.partial(
+            _compute_rps_part, ensemble_dim=self.ensemble_dim, skipna=skipna
+        ),
+        forecast,
+        truth,
+        region=region,
+        skipna=skipna,
+        spatial_agg=False,
     )
+    return result.sum("quantile")
 
 
 @dataclasses.dataclass
diff --git a/weatherbench2/metrics_test.py b/weatherbench2/metrics_test.py
index 1ca88a1..fda07f1 100644
--- a/weatherbench2/metrics_test.py
+++ b/weatherbench2/metrics_test.py
@@ -402,7 +402,7 @@ def test_gaussian_brier_score(self, error, expected_1, expected_2):
     threshold = thresholds.GaussianQuantileThreshold(
         climatology=climatology, quantile=0.8
     )
-    result = metrics.GaussianBrierScore(threshold).compute(forecast, truth)
+    result = metrics.GaussianBrierScore([threshold]).compute(forecast, truth)
     expected_arr = np.array([[expected_1, expected_1]])
     np.testing.assert_allclose(
         result['2m_temperature'].values, expected_arr, rtol=1e-4
@@ -416,7 +416,7 @@ def test_gaussian_brier_score(self, error, expected_1, expected_2):
     threshold = thresholds.QuantileThreshold(
         climatology=climatology, quantile=0.8
     )
-    result = metrics.GaussianBrierScore(threshold).compute(forecast, truth)
+    result = metrics.GaussianBrierScore([threshold]).compute(forecast, truth)
     expected_arr = np.array([[expected_2, expected_2]])
     np.testing.assert_allclose(
         result['2m_temperature'].values, expected_arr, rtol=1e-4
@@ -458,7 +458,9 @@ def test_gaussian_ignorance_score(self, error, expected):
     threshold = thresholds.GaussianQuantileThreshold(
         climatology=climatology, quantile=0.8
     )
-    result = metrics.GaussianIgnoranceScore(threshold).compute(forecast, truth)
+    result = metrics.GaussianIgnoranceScore([threshold]).compute(
+        forecast, truth
+    )
     expected_arr = np.array([[expected, expected]])
     np.testing.assert_allclose(
         result['2m_temperature'].values, expected_arr, rtol=1e-4
@@ -963,7 +965,7 @@ def test_ensemble_brier_score(self, error, ens_delta, expected):
     threshold = thresholds.GaussianQuantileThreshold(
         climatology=climatology, quantile=0.2
     )
-    result = metrics.EnsembleBrierScore(threshold).compute(forecast, truth)
+    result = metrics.EnsembleBrierScore([threshold]).compute(forecast, truth)
     expected_arr = np.array([[expected, expected]])
     np.testing.assert_allclose(
         result['2m_temperature'].values, expected_arr, rtol=1e-4
@@ -1017,7 +1019,7 @@ def test_nan_propagates_to_output_unless_skipna(self, skipna):
     with self.subTest('forecast has nan'):
       # When forecast has nan in prediction_timedelta, only that timedelta will
       # be NaN.
-      result = metrics.EnsembleBrierScore(threshold).compute(
+      result = metrics.EnsembleBrierScore([threshold]).compute(
           forecast_with_nan,
           truth,
          skipna=skipna,
@@ -1034,7 +1036,7 @@ def test_nan_propagates_to_output_unless_skipna(self, skipna):
     with self.subTest('truth has nan'):
       # When truth has nan, the final average over times means the entire
      # score is NaN.
-      result = metrics.EnsembleBrierScore(threshold).compute(
+      result = metrics.EnsembleBrierScore([threshold]).compute(
          forecast,
           truth_with_nan,
           skipna=skipna,
@@ -1080,14 +1082,14 @@ def test_versus_large_ensemble_and_ensure_skipna_works(self):
         quantile=quantile,
     )
 
-    bs_large_ensemble = metrics.EnsembleBrierScore(threshold).compute(
+    bs_large_ensemble = metrics.EnsembleBrierScore([threshold]).compute(
         forecast, truth
     )
-    bs_small_ensemble = metrics.EnsembleBrierScore(threshold).compute(
+    bs_small_ensemble = metrics.EnsembleBrierScore([threshold]).compute(
         small_ensemble_forecast, truth
     )
     bs_debiased_small_ensemble = metrics.DebiasedEnsembleBrierScore(
-        threshold
+        [threshold]
     ).compute(small_ensemble_forecast, truth)
 
     # Get some variants using a bit of NaN values
@@ -1102,13 +1104,13 @@ def test_versus_large_ensemble_and_ensure_skipna_works(self):
         small_ensemble_forecast, frac_nan=frac_nan, seed=0
     )
     truth_w_nan = test_utils.insert_nan(truth, frac_nan=frac_nan, seed=1)
-    bs_small_ensemble_w_nan = metrics.EnsembleBrierScore(threshold).compute(
+    bs_small_ensemble_w_nan = metrics.EnsembleBrierScore([threshold]).compute(
         small_ensemble_forecast_w_nan,
         truth_w_nan,
         skipna=True,
     )
     bs_debiased_small_ensemble_w_nan = metrics.DebiasedEnsembleBrierScore(
-        threshold
+        [threshold]
     ).compute(small_ensemble_forecast_w_nan, truth_w_nan, skipna=True)
 
     # Make sure the test is not trivial by showing that without debiasing we get
@@ -1177,7 +1179,9 @@ def test_ensemble_ignorance_score(self, error, expected):
     threshold = thresholds.GaussianQuantileThreshold(
         climatology=climatology, quantile=0.2
    )
-    result = metrics.EnsembleIgnoranceScore(threshold).compute(forecast, truth)
+    result = metrics.EnsembleIgnoranceScore([threshold]).compute(
+        forecast, truth
+    )
     expected_arr = np.array([[expected, expected]])
     np.testing.assert_allclose(
         result['2m_temperature'].values, expected_arr, rtol=1e-4
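
Usage note: after this change, every threshold-based metric takes a sequence of thresholds (the new `thresholds` field on `ThresholdMetric`) rather than a single threshold or a union type. A minimal sketch of the resulting API, with the caveat that `climatology`, `forecast`, and `truth` are assumed to be pre-loaded xarray Datasets with matching variables, and 'realization' merely stands in for whatever name `ENSEMBLE_DIM` carries in practice:

    from weatherbench2 import metrics, thresholds

    threshold_list = [
        thresholds.GaussianQuantileThreshold(climatology=climatology, quantile=q)
        for q in (0.2, 0.8)
    ]

    # Scalar variant: spatially averaged score, one entry per threshold.
    brier = metrics.EnsembleBrierScore(
        ensemble_dim='realization', thresholds=threshold_list
    ).compute(forecast, truth)

    # Spatial variant: same per-threshold computation, but _map_over_thresholds
    # runs with spatial_agg=False, so latitude/longitude dimensions survive.
    brier_map = metrics.SpatialEnsembleBrierScore(
        ensemble_dim='realization', thresholds=threshold_list
    ).compute(forecast, truth)

Per `ThresholdMetric._map_over_thresholds`, both results gain a leading 'quantile' dimension indexed by each threshold's quantile, plus a 'threshold_method' attribute naming the threshold class.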
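For intuition, the per-threshold ensemble scores reduce at a single grid point to elementary arithmetic. A toy check of `_compute_brier_score` (with debias=False) and `_compute_ignorance_score`, using made-up numbers rather than anything from the patch:

    import numpy as np

    members = np.array([289.0, 290.5, 291.2, 293.0])  # hypothetical ensemble of 2m_temperature
    threshold = 291.0                                 # threshold value at this grid point
    truth = 292.0

    p = (members > threshold).mean()  # forecast exceedance probability: 0.5
    o = float(truth > threshold)      # binarized truth: 1.0
    brier = (p - o) ** 2              # 0.25
    ignorance = -np.log(p) if o else -np.log(1 - p)  # -log(0.5) ~= 0.693

The debiased variant replaces the plain squared error with `_debiased_ensemble_mean_mse`, which corrects the finite-ensemble bias of this estimate (the Fair Brier Score referenced in the docstrings).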
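The Gaussian metrics swap the ensemble's empirical exceedance frequency for the closed-form normal CDF, and both RPS implementations accumulate squared CDF differences across the threshold set, which is what the sum over the 'quantile' dimension above computes. A sketch under the same toy assumptions:

    import numpy as np
    from scipy import stats

    mean, std = 290.0, 1.5  # hypothetical forecast mean and its {var}_std companion
    truth = 292.0

    # Exceedance probability, cf. _compute_gaussian_brier_score:
    p_exceed = 1 - stats.norm.cdf((291.0 - mean) / std)
    brier = (p_exceed - float(truth > 291.0)) ** 2

    # RPS: squared CDF distance summed over a set of thresholds,
    # cf. _compute_gaussian_rps_part and .sum(dim="quantile"):
    levels = np.array([289.0, 291.0, 293.0])
    forecast_cdf = stats.norm.cdf((levels - mean) / std)
    truth_ecdf = (truth < levels).astype(float)
    rps = ((forecast_cdf - truth_ecdf) ** 2).sum()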