Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

No public description #189

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion scripts/compute_averages.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@
'If empty, compute on all data_vars of --input_path'
),
)
SKIPNA = flags.DEFINE_boolean(
'skipna',
False,
help=(
'Whether to skip NaN data points (in forecasts and observations) when'
' evaluating.'
),
)
FANOUT = flags.DEFINE_integer(
'fanout',
None,
Expand Down Expand Up @@ -138,7 +146,9 @@ def main(argv: list[str]):

(
chunked
| xbeam.Mean(AVERAGING_DIMS.value, skipna=False, fanout=FANOUT.value)
| xbeam.Mean(
AVERAGING_DIMS.value, skipna=SKIPNA.value, fanout=FANOUT.value
)
| xbeam.ChunksToZarr(
OUTPUT_PATH.value,
template,
Expand Down
10 changes: 9 additions & 1 deletion scripts/compute_ensemble_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@
' all variables are selected.'
),
)
SKIPNA = flags.DEFINE_boolean(
'skipna',
False,
help=(
'Whether to skip NaN data points (in forecasts and observations) when'
' evaluating.'
),
)


# pylint: disable=expression-not-assigned
Expand Down Expand Up @@ -123,7 +131,7 @@ def main(argv: list[str]):
split_vars=True,
num_threads=NUM_THREADS.value,
)
| xbeam.Mean(REALIZATION_NAME.value, skipna=False)
| xbeam.Mean(REALIZATION_NAME.value, skipna=SKIPNA.value)
| xbeam.ChunksToZarr(
OUTPUT_PATH.value,
template,
Expand Down
15 changes: 13 additions & 2 deletions scripts/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,14 @@
' "2m_temperature"}'
),
)
SKIPNA = flags.DEFINE_boolean(
'skipna',
False,
help=(
'Whether to skip NaN data points (in forecasts and observations) when'
' evaluating.'
),
)
PRESSURE_LEVEL_SUFFIXES = flags.DEFINE_bool(
'pressure_level_suffixes',
False,
Expand Down Expand Up @@ -630,14 +638,17 @@ def main(argv: list[str]) -> None:
eval_configs,
runner=RUNNER.value,
input_chunks=INPUT_CHUNKS.value,
skipna=SKIPNA.value,
fanout=FANOUT.value,
num_threads=NUM_THREADS.value,
argv=argv,
)
else:
evaluation.evaluate_in_memory(data_config, eval_configs)
evaluation.evaluate_in_memory(
data_config, eval_configs, skipna=SKIPNA.value
)


if __name__ == '__main__':
app.run(main)
flags.mark_flag_as_required('output_path')
flags.mark_flags_as_required(['output_path', 'obs_path'])
34 changes: 25 additions & 9 deletions scripts/resample_in_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,14 @@
' use the last time in --input_path.'
),
)
SKIPNA = flags.DEFINE_boolean(
'skipna',
False,
help=(
'Whether to skip NaN data points (in forecasts and observations) when'
' evaluating.'
),
)
WORKING_CHUNKS = flag_utils.DEFINE_chunks(
'working_chunks',
'',
Expand Down Expand Up @@ -182,6 +190,7 @@ def resample_in_time_chunk(
min_vars: list[str],
max_vars: list[str],
add_mean_suffix: bool,
skipna: bool = False,
) -> tuple[xbeam.Key, xr.Dataset]:
"""Resample a data chunk in time and return a requested time statistic.

Expand All @@ -196,6 +205,8 @@ def resample_in_time_chunk(
max_vars: Variables to compute the max of.
add_mean_suffix: Whether to add a "_mean" suffix to variables after
computing the mean.
skipna: Whether to skip NaN values in both forecasts and observations during
evaluation.

Returns:
The resampled data chunk and its key.
Expand All @@ -207,21 +218,23 @@ def resample_in_time_chunk(
for chunk_var in chunk.data_vars:
if chunk_var in mean_vars:
rsmp_chunks.append(
resample_in_time_core(chunk, method, period, 'mean').rename(
resample_in_time_core(
chunk, method, period, 'mean', skipna=skipna
).rename(
{chunk_var: f'{chunk_var}_mean' if add_mean_suffix else chunk_var}
)
)
if chunk_var in min_vars:
rsmp_chunks.append(
resample_in_time_core(chunk, method, period, 'min').rename(
{chunk_var: f'{chunk_var}_min'}
)
resample_in_time_core(
chunk, method, period, 'min', skipna=skipna
).rename({chunk_var: f'{chunk_var}_min'})
)
if chunk_var in max_vars:
rsmp_chunks.append(
resample_in_time_core(chunk, method, period, 'max').rename(
{chunk_var: f'{chunk_var}_max'}
)
resample_in_time_core(
chunk, method, period, 'max', skipna=skipna
).rename({chunk_var: f'{chunk_var}_max'})
)

return rsmp_key, xr.merge(rsmp_chunks)
Expand All @@ -232,6 +245,7 @@ def resample_in_time_core(
method: str,
period: pd.Timedelta,
statistic: str,
skipna: bool,
) -> t.Union[xr.Dataset, xr.DataArray]:
"""Core call to xarray resample or rolling."""
if method == 'rolling':
Expand All @@ -245,12 +259,12 @@ def resample_in_time_core(
{TIME_DIM.value: period // delta_t}, center=False, min_periods=None
),
statistic,
)(skipna=False)
)(skipna=skipna)
elif method == 'resample':
return getattr(
chunk.resample({TIME_DIM.value: period}, label='left'),
statistic,
)(skipna=False)
)(skipna=skipna)
else:
raise ValueError(f'Unhandled {method=}')

Expand Down Expand Up @@ -301,6 +315,7 @@ def main(argv: abc.Sequence[str]) -> None:
METHOD.value,
period,
statistic='mean',
skipna=SKIPNA.value,
)[TIME_DIM.value]
else:
rsmp_times = ds[TIME_DIM.value]
Expand Down Expand Up @@ -369,6 +384,7 @@ def main(argv: abc.Sequence[str]) -> None:
min_vars=min_vars,
max_vars=max_vars,
add_mean_suffix=ADD_MEAN_SUFFIX.value,
skipna=SKIPNA.value,
)
)
| 'RechunkToOutputChunks'
Expand Down
35 changes: 27 additions & 8 deletions weatherbench2/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ def _metric_and_region_loop(
forecast: xr.Dataset,
truth: xr.Dataset,
eval_config: config.Eval,
skipna: bool,
compute_chunk: bool = False,
) -> xr.Dataset:
"""Compute metric results looping over metrics and regions in eval config."""
Expand All @@ -412,16 +413,18 @@ def _metric_and_region_loop(
region_dim = xr.DataArray(
[region_name], coords={'region': [region_name]}
)
tmp_result = eval_fn(forecast=forecast, truth=truth, region=region)
tmp_result = eval_fn(
forecast=forecast, truth=truth, region=region, skipna=skipna
)
tmp_results.append(
tmp_result.expand_dims({'metric': metric_dim, 'region': region_dim})
)
logging.info(f'Logging region done: {region_name}')
result = xr.concat(tmp_results, 'region')
else:
result = eval_fn(forecast=forecast, truth=truth).expand_dims(
{'metric': metric_dim}
)
result = eval_fn(
forecast=forecast, truth=truth, skipna=skipna
).expand_dims({'metric': metric_dim})
results.append(result)
logging.info(f'Logging metric done: {name}')
results = xr.merge(results)
Expand All @@ -432,6 +435,7 @@ def _evaluate_all_metrics(
eval_name: str,
eval_config: config.Eval,
data_config: config.Data,
skipna: bool,
) -> None:
"""Evaluate a set of eval metrics in memory."""
forecast, truth, climatology = open_forecast_and_truth_datasets(
Expand Down Expand Up @@ -463,7 +467,7 @@ def _evaluate_all_metrics(
if data_config.by_init:
truth = truth.sel(time=forecast.valid_time)

results = _metric_and_region_loop(forecast, truth, eval_config)
results = _metric_and_region_loop(forecast, truth, eval_config, skipna=skipna)

logging.info(f'Logging Evaluation complete:\n{results}')

Expand All @@ -475,6 +479,7 @@ def _evaluate_all_metrics(
def evaluate_in_memory(
data_config: config.Data,
eval_configs: dict[str, config.Eval],
skipna: bool = False,
) -> None:
"""Run evaluation in memory.

Expand All @@ -498,9 +503,11 @@ def evaluate_in_memory(
Args:
data_config: config.Data instance.
eval_configs: Dictionary of config.Eval instances.
skipna: Whether to skip NaN values in both forecasts and observations during
evaluation.
"""
for eval_name, eval_config in eval_configs.items():
_evaluate_all_metrics(eval_name, eval_config, data_config)
_evaluate_all_metrics(eval_name, eval_config, data_config, skipna=skipna)


@dataclasses.dataclass
Expand Down Expand Up @@ -547,13 +554,17 @@ class _EvaluateAllMetrics(beam.PTransform):
eval_config: config.Eval instance.
data_config: config.Data instance.
input_chunks: Chunks to use for input files.
skipna: Whether to skip NaN values in both forecasts and observations during
evaluation.
fanout: Fanout parameter for Beam combiners.
num_threads: Number of threads for reading/writing files.
"""

eval_name: str
eval_config: config.Eval
data_config: config.Data
input_chunks: abc.Mapping[str, int]
skipna: bool
fanout: Optional[int] = None
num_threads: Optional[int] = None

Expand All @@ -565,7 +576,11 @@ def _evaluate_chunk(
forecast, truth = forecast_and_truth
logging.info(f'Logging _evaluate_chunk Key: {key}')
results = _metric_and_region_loop(
forecast, truth, self.eval_config, compute_chunk=True
forecast,
truth,
self.eval_config,
compute_chunk=True,
skipna=self.skipna,
)
dropped_dims = [dim for dim in key.offsets if dim not in results.dims]
result_key = key.with_offsets(**{dim: None for dim in dropped_dims})
Expand Down Expand Up @@ -709,7 +724,7 @@ def _evaluate(
forecast_pipeline |= 'TemporalMean' >> xbeam.Mean(
dim='init_time' if self.data_config.by_init else 'time',
fanout=self.fanout,
skipna=False,
skipna=self.skipna,
)

return forecast_pipeline
Expand All @@ -733,6 +748,7 @@ def evaluate_with_beam(
fanout: Optional[int] = None,
num_threads: Optional[int] = None,
argv: Optional[list[str]] = None,
skipna: bool = False,
) -> None:
"""Run evaluation with a Beam pipeline.

Expand Down Expand Up @@ -761,6 +777,8 @@ def evaluate_with_beam(
fanout: Beam CombineFn fanout.
num_threads: Number of threads to use for reading/writing data.
argv: Other arguments to pass into the Beam pipeline.
skipna: Whether to skip NaN values in both forecasts and observations during
evaluation.
"""

with beam.Pipeline(runner=runner, argv=argv) as root:
Expand All @@ -776,6 +794,7 @@ def evaluate_with_beam(
input_chunks,
fanout=fanout,
num_threads=num_threads,
skipna=skipna,
)
| f'save_{eval_name}'
>> _SaveOutputs(
Expand Down
Loading
Loading