From 23b0ef20fb77873606ae9208c0c7fe037a587e4b Mon Sep 17 00:00:00 2001 From: Ian Langmore Date: Tue, 23 Apr 2024 13:03:39 -0700 Subject: [PATCH] No public description PiperOrigin-RevId: 627480182 --- docs/source/command-line-scripts.md | 44 +++++- scripts/compute_averages.py | 17 +- scripts/compute_climatology.py | 16 +- scripts/compute_derived_variables.py | 17 +- scripts/compute_ensemble_mean.py | 19 ++- scripts/compute_statistical_moments.py | 10 +- scripts/compute_zonal_energy_spectrum.py | 17 +- scripts/convert_init_to_valid_time.py | 16 +- scripts/evaluate.py | 6 + scripts/expand_climatology.py | 10 +- scripts/regrid.py | 19 ++- scripts/resample_in_time.py | 39 ++++- scripts/resample_in_time_test.py | 145 ++++++++++++++--- scripts/slice_dataset.py | 192 +++++++++++++++++++++++ scripts/slice_dataset_test.py | 163 +++++++++++++++++++ weatherbench2/evaluation.py | 24 ++- weatherbench2/flag_utils.py | 70 ++++++++- 17 files changed, 775 insertions(+), 49 deletions(-) create mode 100644 scripts/slice_dataset.py create mode 100644 scripts/slice_dataset_test.py diff --git a/docs/source/command-line-scripts.md b/docs/source/command-line-scripts.md index 46ae364..164bb7a 100644 --- a/docs/source/command-line-scripts.md +++ b/docs/source/command-line-scripts.md @@ -429,10 +429,52 @@ _Command options_: * `--working_chunks`: Spatial chunk sizes to use during time downsampling, e.g., "longitude=10,latitude=10". They may not include "time". * `--beam_runner`: Beam runner. Use `DirectRunner` for local execution. +## Slice dataset +Slices a Zarr file containing an xarray Dataset, using `.sel` and `.isel`. + +``` +usage: slice_dataset.py [-h] + [--input_path INPUT_PATH] + [--output_path OUTPUT_PATH] + [--sel SEL] + [--isel ISEL] + [--drop_variables DROP_VARIABLES] + [--keep_variables KEEP_VARIABLES] + [--output_chunks OUTPUT_CHUNKS] + [--runner RUNNER] + +``` + +_Command options_: + +* `--input_path`: (required) Input Zarr path +* `--output_path`: (required) Output Zarr path +* `--sel`: Selection criteria, to pass to `xarray.Dataset.sel`. Passed as + key=value pairs, with key = `VARNAME_{start,stop,step}` +* `--isel`: Selection criteria, to pass to `xarray.Dataset.isel`. Passed as + key=value pairs, with key = `VARNAME_{start,stop,step}` +* `--drop_variables`: Comma delimited list of variables to drop. If empty, drop + no variables. +* `--keep_variables`: Comma delimited list of variables to keep. If empty, use + `--drop_variables` to determine which variables to keep. +* `--output_chunks`: Chunk sizes overriding input chunks. +* `--runner`: Beam runner. Use `DirectRunner` for local execution. + +*Example* + +```bash +python slice_dataset.py -- \ + --input_path=gs://weatherbench2/datasets/ens/2018-64x32_equiangular_with_poles_conservative.zarr \ + --output_path=PATH \ + --sel="prediction_timedelta_stop=15 days,latitude_start=-33.33,latitude_stop=33.33" \ + --isel="longitude_start=0,longitude_stop=180,longitude_step=40" \ + --keep_variables=geopotential,temperature +``` + ## Expand climatology `expand_climatology.py` takes a climatology dataset and expands it into a forecast-like format (`init_time` + `lead_time`). This is not currently used as `evaluation.py` is able to do this on-the-fly, reducing the number of intermediate steps. We still included the script here in case others find it useful. ## Init to valid time conversion -`compute_init_to_valid_time.py` converts a forecasts in init-time convention to valid-time convention. Since currently, we do all evaluation in the init-time format, this script is not used. \ No newline at end of file +`compute_init_to_valid_time.py` converts a forecasts in init-time convention to valid-time convention. Since currently, we do all evaluation in the init-time format, this script is not used. diff --git a/scripts/compute_averages.py b/scripts/compute_averages.py index 23f9088..9ca2460 100644 --- a/scripts/compute_averages.py +++ b/scripts/compute_averages.py @@ -86,6 +86,11 @@ None, help='Beam CombineFn fanout. Might be required for large dataset.', ) +NUM_THREADS = flags.DEFINE_integer( + 'num_threads', + None, + help='Number of chunks to read/write in parallel per worker.', +) # pylint: disable=expression-not-assigned @@ -120,7 +125,10 @@ def main(argv: list[str]): with beam.Pipeline(runner=RUNNER.value, argv=argv) as root: chunked = root | xbeam.DatasetToChunks( - source_dataset, source_chunks, split_vars=True + source_dataset, + source_chunks, + split_vars=True, + num_threads=NUM_THREADS.value, ) if weights is not None: @@ -131,7 +139,12 @@ def main(argv: list[str]): ( chunked | xbeam.Mean(AVERAGING_DIMS.value, skipna=False, fanout=FANOUT.value) - | xbeam.ChunksToZarr(OUTPUT_PATH.value, template, target_chunks) + | xbeam.ChunksToZarr( + OUTPUT_PATH.value, + template, + target_chunks, + num_threads=NUM_THREADS.value, + ) ) diff --git a/scripts/compute_climatology.py b/scripts/compute_climatology.py index a32f296..6269d04 100644 --- a/scripts/compute_climatology.py +++ b/scripts/compute_climatology.py @@ -120,6 +120,11 @@ 'precipitation variable. In mm.' ), ) +NUM_THREADS = flags.DEFINE_integer( + 'num_threads', + None, + help='Number of chunks to read/write in parallel per worker.', +) class Quantile: @@ -330,6 +335,10 @@ def _compute_seeps(kv): if stat not in ['seeps', 'mean']: for var in raw_vars: if stat == 'quantile': + if not quantiles: + raise ValueError( + 'Cannot compute stat `quantile` without specifying --quantiles.' + ) quantile_dim = xr.DataArray( quantiles, name='quantile', dims=['quantile'] ) @@ -349,7 +358,10 @@ def _compute_seeps(kv): pcoll = ( root | xbeam.DatasetToChunks( - obs, input_chunks, split_vars=True, num_threads=16 + obs, + input_chunks, + split_vars=True, + num_threads=NUM_THREADS.value, ) | 'RechunkIn' >> xbeam.Rechunk( # pytype: disable=wrong-arg-types @@ -412,7 +424,7 @@ def _compute_seeps(kv): OUTPUT_PATH.value, template=clim_template, zarr_chunks=output_chunks, - num_threads=16, + num_threads=NUM_THREADS.value, ) ) diff --git a/scripts/compute_derived_variables.py b/scripts/compute_derived_variables.py index 25b6aae..30a49e2 100644 --- a/scripts/compute_derived_variables.py +++ b/scripts/compute_derived_variables.py @@ -116,6 +116,11 @@ MAX_MEM_GB = flags.DEFINE_integer( 'max_mem_gb', 1, help='Max memory for rechunking in GB.' ) +NUM_THREADS = flags.DEFINE_integer( + 'num_threads', + None, + help='Number of chunks to read/write in parallel per worker.', +) RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner') @@ -226,7 +231,12 @@ def _is_not_precip(kv: tuple[xbeam.Key, xr.Dataset]) -> bool: # so that with and without rechunking can be computed in parallel pcoll = ( root - | xbeam.DatasetToChunks(source_dataset, source_chunks, split_vars=False) + | xbeam.DatasetToChunks( + source_dataset, + source_chunks, + split_vars=False, + num_threads=NUM_THREADS.value, + ) | beam.MapTuple( lambda k, v: ( # pylint: disable=g-long-lambda k, @@ -274,7 +284,10 @@ def _is_not_precip(kv: tuple[xbeam.Key, xr.Dataset]) -> bool: # Combined _ = pcoll | xbeam.ChunksToZarr( - OUTPUT_PATH.value, template, source_chunks, num_threads=16 + OUTPUT_PATH.value, + template, + source_chunks, + num_threads=NUM_THREADS.value, ) diff --git a/scripts/compute_ensemble_mean.py b/scripts/compute_ensemble_mean.py index 736ea8d..5e6030a 100644 --- a/scripts/compute_ensemble_mean.py +++ b/scripts/compute_ensemble_mean.py @@ -61,6 +61,11 @@ '2020-12-31', help='ISO 8601 timestamp (inclusive) at which to stop evaluation', ) +NUM_THREADS = flags.DEFINE_integer( + 'num_threads', + None, + help='Number of chunks to read/write in parallel per worker.', +) # pylint: disable=expression-not-assigned @@ -88,9 +93,19 @@ def main(argv: list[str]): with beam.Pipeline(runner=RUNNER.value, argv=argv) as root: ( root - | xbeam.DatasetToChunks(source_dataset, source_chunks, split_vars=True) + | xbeam.DatasetToChunks( + source_dataset, + source_chunks, + split_vars=True, + num_threads=NUM_THREADS.value, + ) | xbeam.Mean(REALIZATION_NAME.value, skipna=False) - | xbeam.ChunksToZarr(OUTPUT_PATH.value, template, target_chunks) + | xbeam.ChunksToZarr( + OUTPUT_PATH.value, + template, + target_chunks, + num_threads=NUM_THREADS.value, + ) ) diff --git a/scripts/compute_statistical_moments.py b/scripts/compute_statistical_moments.py index 31ab6bb..281bfbd 100644 --- a/scripts/compute_statistical_moments.py +++ b/scripts/compute_statistical_moments.py @@ -37,6 +37,11 @@ RECHUNK_ITEMSIZE = flags.DEFINE_integer( 'rechunk_itemsize', 4, help='Itemsize for rechunking.' ) +NUM_THREADS = flags.DEFINE_integer( + 'num_threads', + None, + help='Number of chunks to read/write in parallel per worker.', +) def moment_reduce( @@ -143,7 +148,9 @@ def main(argv: list[str]) -> None: with beam.Pipeline(runner=RUNNER.value, argv=argv) as root: # Read - pcoll = root | xbeam.DatasetToChunks(obs, input_chunks, split_vars=True) + pcoll = root | xbeam.DatasetToChunks( + obs, input_chunks, split_vars=True, num_threads=NUM_THREADS.value + ) # Branches to compute statistical moments pcolls = [] @@ -174,6 +181,7 @@ def main(argv: list[str]) -> None: OUTPUT_PATH.value, template=output_template, zarr_chunks=output_chunks, + num_threads=NUM_THREADS.value, ) ) diff --git a/scripts/compute_zonal_energy_spectrum.py b/scripts/compute_zonal_energy_spectrum.py index d1173ce..ee1595f 100644 --- a/scripts/compute_zonal_energy_spectrum.py +++ b/scripts/compute_zonal_energy_spectrum.py @@ -96,6 +96,11 @@ None, help='Beam CombineFn fanout. Might be required for large dataset.', ) +NUM_THREADS = flags.DEFINE_integer( + 'num_threads', + None, + help='Number of chunks to read/write in parallel per worker.', +) RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner') @@ -196,7 +201,12 @@ def main(argv: list[str]) -> None: with beam.Pipeline(runner=RUNNER.value, argv=argv) as root: _ = ( root - | xbeam.DatasetToChunks(source_dataset, source_chunks, split_vars=False) + | xbeam.DatasetToChunks( + source_dataset, + source_chunks, + split_vars=False, + num_threads=NUM_THREADS.value, + ) | beam.MapTuple( lambda k, v: ( # pylint: disable=g-long-lambda k, @@ -207,7 +217,10 @@ def main(argv: list[str]) -> None: | beam.MapTuple(_strip_offsets) | xbeam.Mean(AVERAGING_DIMS.value, fanout=FANOUT.value) | xbeam.ChunksToZarr( - OUTPUT_PATH.value, template, output_chunks, num_threads=16 + OUTPUT_PATH.value, + template, + output_chunks, + num_threads=NUM_THREADS.value, ) ) diff --git a/scripts/convert_init_to_valid_time.py b/scripts/convert_init_to_valid_time.py index 3e7f84c..446d61f 100644 --- a/scripts/convert_init_to_valid_time.py +++ b/scripts/convert_init_to_valid_time.py @@ -102,6 +102,11 @@ INPUT_PATH = flags.DEFINE_string('input_path', None, help='zarr inputs') OUTPUT_PATH = flags.DEFINE_string('output_path', None, help='zarr outputs') RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner') +NUM_THREADS = flags.DEFINE_integer( + 'num_threads', + None, + help='Number of chunks to read/write in parallel per worker.', +) TIME = 'time' DELTA = 'prediction_timedelta' @@ -254,7 +259,9 @@ def main(argv: list[str]) -> None: source_ds.indexes[INIT], ) ) - p |= xarray_beam.DatasetToChunks(source_ds, input_chunks, split_vars=True) + p |= xarray_beam.DatasetToChunks( + source_ds, input_chunks, split_vars=True, num_threads=NUM_THREADS.value + ) if input_chunks != split_chunks: p |= xarray_beam.SplitChunks(split_chunks) p |= beam.FlatMapTuple( @@ -266,7 +273,12 @@ def main(argv: list[str]) -> None: p = (p, padding) | beam.Flatten() if input_chunks != split_chunks: p |= xarray_beam.ConsolidateChunks(output_chunks) - p |= xarray_beam.ChunksToZarr(OUTPUT_PATH.value, template, output_chunks) + p |= xarray_beam.ChunksToZarr( + OUTPUT_PATH.value, + template, + output_chunks, + num_threads=NUM_THREADS.value, + ) if __name__ == '__main__': diff --git a/scripts/evaluate.py b/scripts/evaluate.py index 153d968..be6fbd0 100644 --- a/scripts/evaluate.py +++ b/scripts/evaluate.py @@ -249,6 +249,11 @@ None, help='Beam CombineFn fanout. Might be required for large dataset.', ) +NUM_THREADS = flags.DEFINE_integer( + 'num_threads', + None, + help='Number of chunks to read/write Zarr in parallel per worker.', +) def _wind_vector_error(err_type: str): @@ -623,6 +628,7 @@ def main(argv: list[str]) -> None: runner=RUNNER.value, input_chunks=INPUT_CHUNKS.value, fanout=FANOUT.value, + num_threads=NUM_THREADS.value, argv=argv, ) else: diff --git a/scripts/expand_climatology.py b/scripts/expand_climatology.py index 699e9c8..c330328 100644 --- a/scripts/expand_climatology.py +++ b/scripts/expand_climatology.py @@ -72,6 +72,11 @@ None, help='Desired integer chunk size. If not set, inferred from input chunks.', ) +NUM_THREADS = flags.DEFINE_integer( + 'num_threads', + None, + help='Number of chunks to read/write in parallel per worker.', +) RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner') @@ -149,7 +154,10 @@ def main(argv: list[str]) -> None: | beam.Reshuffle() | beam.FlatMap(select_climatology, climatology, times, base_chunks) | xbeam.ChunksToZarr( - OUTPUT_PATH.value, template=template, zarr_chunks=output_chunks + OUTPUT_PATH.value, + template=template, + zarr_chunks=output_chunks, + num_threads=NUM_THREADS.value, ) ) diff --git a/scripts/regrid.py b/scripts/regrid.py index 79c0e6f..54c070a 100644 --- a/scripts/regrid.py +++ b/scripts/regrid.py @@ -78,6 +78,11 @@ LONGITUDE_NAME = flags.DEFINE_string( 'longitude_name', 'longitude', help='Name of longitude dimension in dataset' ) +NUM_THREADS = flags.DEFINE_integer( + 'num_threads', + None, + help='Number of chunks to read/write in parallel per worker.', +) RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner') @@ -135,11 +140,21 @@ def main(argv): with beam.Pipeline(runner=RUNNER.value, argv=argv) as root: _ = ( root - | xarray_beam.DatasetToChunks(source_ds, input_chunks, split_vars=True) + | xarray_beam.DatasetToChunks( + source_ds, + input_chunks, + split_vars=True, + num_threads=NUM_THREADS.value, + ) | 'Regrid' >> beam.MapTuple(lambda k, v: (k, regridder.regrid_dataset(v))) | xarray_beam.ConsolidateChunks(output_chunks) - | xarray_beam.ChunksToZarr(OUTPUT_PATH.value, template, output_chunks) + | xarray_beam.ChunksToZarr( + OUTPUT_PATH.value, + template, + output_chunks, + num_threads=NUM_THREADS.value, + ) ) diff --git a/scripts/resample_in_time.py b/scripts/resample_in_time.py index 192a547..dd0e0c9 100644 --- a/scripts/resample_in_time.py +++ b/scripts/resample_in_time.py @@ -70,7 +70,12 @@ 'method', 'resample', ['resample', 'rolling'], - help='Whether to resample to new times, or use a rolling window.', + help=( + 'Whether to resample to new times (spaced by --period), or use a' + ' rolling window. In either case, output at time index T uses the' + ' window [T, T + period]. In particular, whether using resample or' + ' rolling, output at matching times will be the same.' + ), ) PERIOD = flags.DEFINE_string( 'period', @@ -114,7 +119,9 @@ help='Add suffix "_mean" to variable name when computing the mean.', ) NUM_THREADS = flags.DEFINE_integer( - 'num_threads', None, help='Number of chunks to load in parallel per worker.' + 'num_threads', + None, + help='Number of chunks to read/write in parallel per worker.', ) TIME_DIM = flags.DEFINE_string( 'time_dim', 'time', help='Name for the time dimension to slice data on.' @@ -234,14 +241,16 @@ def resample_in_time_core( f'{delta_t=} between chunk times did not evenly divide {period=}' ) return getattr( - chunk.rolling({TIME_DIM.value: period // delta_t}), + chunk.rolling( + {TIME_DIM.value: period // delta_t}, center=False, min_periods=None + ), statistic, - )() + )(skipna=False) elif method == 'resample': return getattr( - chunk.resample({TIME_DIM.value: period}), + chunk.resample({TIME_DIM.value: period}, label='left'), statistic, - )() + )(skipna=False) else: raise ValueError(f'Unhandled {method=}') @@ -249,6 +258,8 @@ def resample_in_time_core( def main(argv: abc.Sequence[str]) -> None: ds, input_chunks = xbeam.open_zarr(INPUT_PATH.value) + period = pd.to_timedelta(PERIOD.value) + if TIME_START.value is not None or TIME_STOP.value is not None: ds = ds.sel({TIME_DIM.value: slice(TIME_START.value, TIME_STOP.value)}) @@ -267,8 +278,22 @@ def main(argv: abc.Sequence[str]) -> None: ) ds = ds[keep_vars] + # To ensure results at time T use data from [T, T + period], an offset needs + # to be added if the method is rolling. + # It would be wonderful if this was the default, or possible with appropriate + # kwargs in rolling, but alas... + if METHOD.value == 'rolling': + delta_ts = pd.to_timedelta(np.unique(np.diff(ds[TIME_DIM.value].data))) + if len(delta_ts) != 1: + raise ValueError( + f'Input data must have constant spacing. Found {delta_ts}' + ) + delta_t = delta_ts[0] + ds = ds.assign_coords( + {TIME_DIM.value: ds[TIME_DIM.value] - period + delta_t} + ) + # Make the template - period = pd.to_timedelta(PERIOD.value) if METHOD.value == 'resample': rsmp_times = resample_in_time_core( # All stats will give the same times, so use 'mean' arbitrarily. diff --git a/scripts/resample_in_time_test.py b/scripts/resample_in_time_test.py index 3429744..6010f53 100644 --- a/scripts/resample_in_time_test.py +++ b/scripts/resample_in_time_test.py @@ -27,7 +27,7 @@ class ResampleInTimeTest(parameterized.TestCase): - def test_demonstrating_returned_times_for_resample(self): + def test_demonstrating_resample_and_rolling_are_aligned(self): # times = 10 days, starting at Jan 1 times = pd.DatetimeIndex( [ @@ -43,7 +43,13 @@ def test_demonstrating_returned_times_for_resample(self): '2023-01-10', ] ) - temperatures = np.arange(len(times)) + temperatures = np.arange(len(times)).astype(float) + + # NaN inserted to (i) verify skipna=False, and (ii) verify correct setting + # for min_periods. If e.g. min_periods=1, then NaN values get skipped so + # long as there is at least one non-NaN value! + temperatures[0] = np.nan + input_ds = xr.Dataset( { 'temperature': xr.DataArray( @@ -53,37 +59,128 @@ def test_demonstrating_returned_times_for_resample(self): ) input_path = self.create_tempdir('source').full_path - output_path = self.create_tempdir('destination').full_path - input_ds.to_zarr(input_path) + # Get resampled output + resample_output_path = self.create_tempdir('resample').full_path with flagsaver.as_parsed( input_path=input_path, - output_path=output_path, + output_path=resample_output_path, method='resample', - period='1w', + period='3d', mean_vars='ALL', runner='DirectRunner', ): resample_in_time.main([]) + resample, unused_output_chunks = xarray_beam.open_zarr(resample_output_path) - output_ds, unused_output_chunks = xarray_beam.open_zarr(output_path) + # Show that the output at time T uses data from the window [T, T + period] np.testing.assert_array_equal( - output_ds.time, - # The first output time is the first input time - # The second output time is the first + 1w - np.array( - ['2023-01-01T00:00:00.000000000', '2023-01-08T00:00:00.000000000'], - dtype='datetime64[ns]', + pd.to_datetime(resample.time), + pd.DatetimeIndex( + ['2023-01-01', '2023-01-04', '2023-01-07', '2023-01-10'] ), ) np.testing.assert_array_equal( - output_ds.temperature.data, - # The first temperature is the average of the first 7 times - # The second temperature is the average of the remaining times (of which - # there are only 3). - [np.mean(temperatures[:7]), np.mean(temperatures[7:14])], + resample.temperature.data, + [ + np.mean(temperatures[:3]), # Will be NaN + np.mean(temperatures[3:6]), + np.mean(temperatures[6:9]), + np.mean(temperatures[9:12]), + ], + ) + + # Get rolled output + rolling_output_path = self.create_tempdir('rolling').full_path + with flagsaver.as_parsed( + input_path=input_path, + output_path=rolling_output_path, + method='rolling', + period='3d', + mean_vars='ALL', + runner='DirectRunner', + ): + resample_in_time.main([]) + rolling, unused_output_chunks = xarray_beam.open_zarr(rolling_output_path) + + common_times = pd.DatetimeIndex(['2023-01-01', '2023-01-04', '2023-01-07']) + xr.testing.assert_equal( + resample.sel(time=common_times), + rolling.sel(time=common_times), + ) + + @parameterized.parameters( + (20, '3d', None), + (21, '3d', None), + (21, '8d', None), + (5, '1d', None), + (20, '3d', [0, 4, 8]), + (21, '3d', [20]), + (21, '8d', [15]), + ) + def test_demonstrating_resample_and_rolling_are_aligned_many_combinations( + self, + n_times, + period, + nan_locations, + ): + # Less readable than test_demonstrating_resample_and_rolling_are_aligned, + # but these sorts of automated checks ensure we didn't miss an edge case + # (there are many!!!!) + times = pd.date_range('2010', periods=n_times) + temperatures = np.random.RandomState(802701).rand(n_times) + + for i in nan_locations or []: + temperatures[i] = np.nan + + input_ds = xr.Dataset( + { + 'temperature': xr.DataArray( + temperatures, coords=[times], dims=['time'] + ) + } + ) + + input_path = self.create_tempdir('source').full_path + input_ds.to_zarr(input_path) + + # Get resampled output + resample_output_path = self.create_tempdir('resample').full_path + with flagsaver.as_parsed( + input_path=input_path, + output_path=resample_output_path, + method='resample', + period=period, + mean_vars='ALL', + runner='DirectRunner', + ): + resample_in_time.main([]) + resample, unused_output_chunks = xarray_beam.open_zarr(resample_output_path) + + # Get rolled output + rolling_output_path = self.create_tempdir('rolling').full_path + with flagsaver.as_parsed( + input_path=input_path, + output_path=rolling_output_path, + method='rolling', + period=period, + mean_vars='ALL', + runner='DirectRunner', + ): + resample_in_time.main([]) + rolling, unused_output_chunks = xarray_beam.open_zarr(rolling_output_path) + + common_times = pd.to_datetime(resample.time.data).intersection( + rolling.time.data + ) + + # At most, one time is lost if the period doesn't evenly divide n_times. + self.assertGreaterEqual(len(common_times), len(resample.time) - 1) + xr.testing.assert_equal( + resample.sel(time=common_times), + rolling.sel(time=common_times), ) @parameterized.named_parameters( @@ -190,6 +287,12 @@ def test_resample_time(self, method, add_mean_suffix, period): // pd.to_timedelta(input_time_resolution) ).mean() ) + # Enact the time offsetting needed to align resample and rolling. + expected_mean = expected_mean.assign_coords( + time=expected_mean.time + - pd.to_timedelta(period) + + pd.to_timedelta(input_time_resolution) + ) else: raise ValueError(f'Unhandled {method=}') @@ -337,6 +440,12 @@ def test_resample_prediction_timedelta(self, method, add_mean_suffix, period): // pd.to_timedelta(input_time_resolution) ).mean() ) + # Enact the time offsetting needed to align resample and rolling. + expected_mean = expected_mean.assign_coords( + prediction_timedelta=expected_mean.prediction_timedelta + - pd.to_timedelta(period) + + pd.to_timedelta(input_time_resolution) + ) else: raise ValueError(f'Unhandled {method=}') diff --git a/scripts/slice_dataset.py b/scripts/slice_dataset.py new file mode 100644 index 0000000..9c7b18a --- /dev/null +++ b/scripts/slice_dataset.py @@ -0,0 +1,192 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""CLI to slice a Zarr file containing a xarray.Dataset. + + +Example Usage: + + ``` + export BUCKET=my-bucket + export PROJECT=my-project + export REGION=us-central1 + + python scripts/resample_in_time.py \ + --input_path=gs://weatherbench2/datasets/era5/1959-2022-6h-64x32_equiangular_with_poles_conservative.zarr \ + --output_path=gs://$BUCKET/datasets/era5/$USER/2020-2021-weekly-average-temperature.zarr \ + --runner=DataflowRunner \ + --sel="prediction_timedelta_stop=15 days,latitude_start=-33.33,latitude_stop=33.33" \ + --isel="longitude_start=0,longitude_stop=180,longitude_step=40" \ + --keep_variables=geopotential,temperature \ + -- \ + --project=$PROJECT \ + --temp_location=gs://$BUCKET/tmp/ \ + --setup_file=./setup.py \ + --requirements_file=./scripts/dataflow-requirements.txt \ + --job_name=slice-dataset-$USER + ``` +""" + +from collections import abc +import re + +from absl import app +from absl import flags +import apache_beam as beam +from weatherbench2 import flag_utils +import xarray_beam as xbeam + + +# Command line arguments +INPUT_PATH = flags.DEFINE_string('input_path', None, help='Input Zarr path.') +OUTPUT_PATH = flags.DEFINE_string('output_path', None, help='Output Zarr path.') + +SEL = flag_utils.DEFINE_dim_value_pairs( + 'sel', + '', + help=( + 'Selection criteria, to pass to xarray.Dataset.sel. Passed as key=value' + ' pairs, with key = VARNAME_{start,stop,step}' + ), +) + +ISEL = flag_utils.DEFINE_dim_integer_pairs( + 'isel', + '', + help=( + 'Selection criteria, to pass to xarray.Dataset.isel. Passed as' + ' key=value pairs, with key = VARNAME_{start,stop,step}' + ), +) + +DROP_VARIABLES = flags.DEFINE_list( + 'drop_variables', + None, + help=( + 'Comma delimited list of variables to drop. If empty, drop no' + ' variables.' + ), +) + +KEEP_VARIABLES = flags.DEFINE_list( + 'keep_variables', + None, + help=( + 'Comma delimited list of variables to keep. If empty, use' + ' --drop_variables to determine which variables to keep' + ), +) + +OUTPUT_CHUNKS = flag_utils.DEFINE_chunks( + 'output_chunks', '', help='Chunk sizes overriding input chunks.' +) + +RUNNER = flags.DEFINE_string( + 'runner', None, help='Beam runner. Use DirectRunner for local execution.' +) +NUM_THREADS = flags.DEFINE_integer( + 'num_threads', + None, + help='Number of chunks to read/write in parallel per worker.', +) + + +def _get_selections( + isel_flag_value: dict[str, int], + sel_flag_value: dict[str, flag_utils.DimValueType], +) -> tuple[dict[str, slice], dict[str, slice]]: + """Gets dictionaries for `xr.isel` and `xr.sel`.""" + isel_parts = {} + sel_parts = {} + for parts_dict, flag_value in [ + (isel_parts, isel_flag_value), + (sel_parts, sel_flag_value), + ]: + for k, v in flag_value.items(): + match = re.search(r'^(.*)_(start|stop|step)$', k) + if not match: + raise ValueError(f'Flag {k} did not end in _(start|stop|step)') + dim, placement = match.groups() + if dim not in parts_dict: + parts_dict[dim] = [None, None, None] + if placement == 'start': + parts_dict[dim][0] = v + elif placement == 'stop': + parts_dict[dim][1] = v + else: + parts_dict[dim][2] = v + + overlap = set(isel_parts).intersection(sel_parts) + if overlap: + raise ValueError( + f'--isel {isel_flag_value} and --sel {sel_flag_value} overlapped for' + f' variables {overlap}' + ) + isel = {k: slice(*v) for k, v in isel_parts.items()} + sel = {k: slice(*v) for k, v in sel_parts.items()} + return isel, sel + + +def main(argv: abc.Sequence[str]) -> None: + + ds, input_chunks = xbeam.open_zarr(INPUT_PATH.value) + + if DROP_VARIABLES.value: + ds = ds[[v for v in ds if v not in DROP_VARIABLES.value]] + elif KEEP_VARIABLES.value: + ds = ds[KEEP_VARIABLES.value] + + isel, sel = _get_selections(ISEL.value, SEL.value) + if isel: + ds = ds.isel(isel) + if sel: + ds = ds.sel(sel) + + template = xbeam.make_template(ds) + + output_chunks = {k: v for k, v in input_chunks.items()} # Copy + for k in output_chunks: + if k in OUTPUT_CHUNKS.value: + output_chunks[k] = OUTPUT_CHUNKS.value[k] + else: + output_chunks[k] = min(output_chunks[k], ds.sizes[k]) + + itemsize = max(var.dtype.itemsize for var in template.values()) + + with beam.Pipeline(runner=RUNNER.value, argv=argv) as root: + # Read, rechunk, write + unused_pcoll = ( + root + | xbeam.DatasetToChunks( + ds, input_chunks, split_vars=True, num_threads=NUM_THREADS.value + ) + | xbeam.Rechunk( # pytype: disable=wrong-arg-types + ds.sizes, + input_chunks, + output_chunks, + itemsize=itemsize, + ) + | xbeam.ChunksToZarr( + OUTPUT_PATH.value, + template=template, + zarr_chunks=output_chunks, + num_threads=NUM_THREADS.value, + ) + ) + + +if __name__ == '__main__': + flags.mark_flags_as_required(['input_path', 'output_path']) + flags.mark_flags_as_mutual_exclusive(['keep_variables', 'drop_variables']) + app.run(main) diff --git a/scripts/slice_dataset_test.py b/scripts/slice_dataset_test.py new file mode 100644 index 0000000..e1fa7d8 --- /dev/null +++ b/scripts/slice_dataset_test.py @@ -0,0 +1,163 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from absl.testing import absltest +from absl.testing import flagsaver +from absl.testing import parameterized +from weatherbench2 import schema +from weatherbench2 import utils +import xarray as xr +import xarray_beam + +from . import slice_dataset + + +class GetSelectionsTest(parameterized.TestCase): + + def test_valid_combinations(self): + isel, sel = slice_dataset._get_selections( + isel_flag_value={ + 'X_start': 0, + 'X_stop': 10, + 'X_step': 2, + 'Y_stop': 4, + 'Z_start': 1, + 'W_step': 2, + }, + sel_flag_value={ + 'A_start': '1 day', + 'A_stop': '10 days', + 'A_step': '2 days', + }, + ) + expected_isel = { + 'X': slice(0, 10, 2), + 'Y': slice(None, 4, None), + 'Z': slice(1, None, None), + 'W': slice(None, None, 2), + } + expected_sel = { + 'A': slice('1 day', '10 days', '2 days'), + } + self.assertEqual(expected_isel, isel) + self.assertEqual(expected_sel, sel) + + def test_invalid_placement_raises(self): + with self.subTest('Not ending in (start|stop|step) raises'): + with self.assertRaisesRegex(ValueError, 'did not end in'): + slice_dataset._get_selections( + isel_flag_value={ + 'X_start': 0, + 'X_stop': 10, + 'X_bad': 2, + }, + sel_flag_value={}, + ) + + with self.subTest('Not ending in (start|stop|step) raises 2'): + with self.assertRaisesRegex(ValueError, 'did not end in'): + slice_dataset._get_selections( + isel_flag_value={ + 'X_start': 0, + 'X_stop': 10, + 'X_step_and_more': 2, + }, + sel_flag_value={}, + ) + + with self.subTest('Not ending in (start|stop|step) raises 2'): + with self.assertRaisesRegex(ValueError, 'did not end in'): + slice_dataset._get_selections( + isel_flag_value={ + 'X_start': 0, + 'X_stop': 10, + 'X_step_': 2, + }, + sel_flag_value={}, + ) + + def test_overlapping_dims_raise(self): + with self.assertRaisesRegex(ValueError, 'overlapped'): + slice_dataset._get_selections( + isel_flag_value={ + 'X_start': 0, + 'X_stop': 10, + }, + sel_flag_value={ + 'X_step': 2, + }, + ) + + +class SliceDatasetTest(parameterized.TestCase): + + def test_simple_slicing(self): + input_ds = utils.random_like( + schema.mock_truth_data( + variables_2d=[], + variables_3d=['temperature', 'geopotential', 'should_drop'], + # time_start/stop for the raw data is wider than the times for + # resampled data. + time_start='2021-01-01', + time_stop='2022-01-01', + spatial_resolution_in_degrees=30.0, + time_resolution='1d', + ) + ) + + # Make variables different so we test that variables are handled + # individually. + input_ds = input_ds.assign({'geopotential': input_ds.geopotential + 10}) + + input_path = self.create_tempdir('source').full_path + output_path = self.create_tempdir('destination').full_path + + input_chunks = {'time': 40, 'longitude': 6, 'latitude': 5, 'level': 3} + input_ds.chunk(input_chunks).to_zarr(input_path) + + with flagsaver.as_parsed( + input_path=input_path, + output_path=output_path, + output_chunks='level=1', + sel=( + # Note that time_step is an integer, since pandas requires this. + 'time_start=2021-02-01,time_stop=2021-04-01,time_step=5,' + 'longitude_step=60' + ), + isel='latitude_stop=5', + drop_variables='should_drop', + runner='DirectRunner', + ): + slice_dataset.main([]) + + output_ds, output_chunks = xarray_beam.open_zarr(output_path) + expected_output_ds = input_ds.sel( + time=slice('2021-02-01', '2021-04-01', 5), + longitude=slice(None, None, 60), + ).isel(latitude=slice(5))[['temperature', 'geopotential']] + xr.testing.assert_equal(output_ds, expected_output_ds) + + expected_output_chunks = { + 'time': min(input_chunks['time'], output_ds.sizes['time']), + 'longitude': min( + input_chunks['longitude'], output_ds.sizes['longitude'] + ), + 'latitude': min(input_chunks['latitude'], output_ds.sizes['latitude']), + 'level': 1, # level was explicitly specified + } + self.assertEqual(expected_output_chunks, output_chunks) + + +if __name__ == '__main__': + absltest.main() diff --git a/weatherbench2/evaluation.py b/weatherbench2/evaluation.py index f797209..c5c5e7b 100644 --- a/weatherbench2/evaluation.py +++ b/weatherbench2/evaluation.py @@ -509,6 +509,7 @@ class _SaveOutputs(beam.PTransform): eval_name: str data_config: config.Data output_format: str + num_threads: Optional[int] = None def _write_netcdf(self, datasets: list[xr.Dataset]) -> xr.Dataset: combined = xr.combine_by_coords(datasets) @@ -529,7 +530,9 @@ def expand(self, pcoll: beam.PCollection) -> beam.PCollection: output_path = _get_output_path( self.data_config, self.eval_name, self.output_format ) - return pcoll | xbeam.ChunksToZarr(output_path) + return pcoll | xbeam.ChunksToZarr( + output_path, num_threads=self.num_threads + ) else: raise ValueError(f'unrecogonized data format: {self.output_format}') @@ -551,6 +554,7 @@ class _EvaluateAllMetrics(beam.PTransform): data_config: config.Data input_chunks: abc.Mapping[str, int] fanout: Optional[int] = None + num_threads: Optional[int] = None def _evaluate_chunk( self, @@ -662,12 +666,14 @@ def _evaluate( forecast, self.input_chunks, split_vars=False, + num_threads=self.num_threads, ) | beam.MapTuple(self._sel_corresponding_truth_chunk, truth=truth) else: forecast_pipeline = xbeam.DatasetToChunks( [forecast, truth], self.input_chunks, split_vars=False, + num_threads=self.num_threads, ) if self.eval_config.evaluate_climatology: @@ -723,6 +729,7 @@ def evaluate_with_beam( input_chunks: abc.Mapping[str, int], runner: str, fanout: Optional[int] = None, + num_threads: Optional[int] = None, argv: Optional[list[str]] = None, ) -> None: """Run evaluation with a Beam pipeline. @@ -750,6 +757,7 @@ def evaluate_with_beam( input_chunks: Chunking of input datasets. runner: Beam runner. fanout: Beam CombineFn fanout. + num_threads: Number of threads to use for reading/writing data. argv: Other arguments to pass into the Beam pipeline. """ @@ -760,8 +768,18 @@ def evaluate_with_beam( root | f'evaluate_{eval_name}' >> _EvaluateAllMetrics( - eval_name, eval_config, data_config, input_chunks, fanout=fanout + eval_name, + eval_config, + data_config, + input_chunks, + fanout=fanout, + num_threads=num_threads, ) | f'save_{eval_name}' - >> _SaveOutputs(eval_name, data_config, eval_config.output_format) + >> _SaveOutputs( + eval_name, + data_config, + eval_config.output_format, + num_threads=num_threads, + ) ) diff --git a/weatherbench2/flag_utils.py b/weatherbench2/flag_utils.py index 58e43c4..b5e6673 100644 --- a/weatherbench2/flag_utils.py +++ b/weatherbench2/flag_utils.py @@ -14,10 +14,12 @@ # ============================================================================== """WeatherBench2 utilities for working with command line flags.""" import re -from typing import Any +from typing import Any, Union from absl import flags +DimValueType = Union[int, float, str] + def _chunks_string_is_valid(chunks_string: str) -> bool: return re.fullmatch(r'(\w+=-?\d+(,\w+=-?\d+)*)?', chunks_string) is not None @@ -52,8 +54,8 @@ def flag_type(self) -> str: return 'dict[str, int]' -class _ChunksSerializer(flags.ArgumentSerializer): - """Serializer for Xarray-Beam chunks flags.""" +class _DimValuePairSerializer(flags.ArgumentSerializer): + """Serializer for dim=value pairs.""" def serialize(self, value: dict[str, int]) -> str: return ','.join(f'{k}={v}' for k, v in value.items()) @@ -67,7 +69,67 @@ def DEFINE_chunks( # pylint: disable=invalid-name ): """Define a flag for defining Xarray-Beam chunks.""" parser = _ChunksParser() - serializer = _ChunksSerializer() + serializer = _DimValuePairSerializer() + return flags.DEFINE( + parser, name, default, help, serializer=serializer, **kwargs + ) + + +# Key/value pairs of the form dimension=integer have the same requirements as +# chunks. +DEFINE_dim_integer_pairs = DEFINE_chunks + + +class _DimValuePairParser(flags.ArgumentParser): + """Parser for dim=value pairs.""" + + syntactic_help: str = ( + 'comma separate list of dim=value pairs, e.g.,' + '"time=0 days,longitude=100"' + ) + + def parse(self, argument: str) -> dict[str, DimValueType]: + return _parse_dim_value_pairs(argument) + + def flag_type(self) -> str: + """Returns a string representing the type of the flag.""" + return 'dict[str, int | float | str]' + + +def _get_dim_value(value_string: str) -> DimValueType: + """Tries returning int then float, fallback to string.""" + try: + return int(value_string) + except ValueError: + pass + try: + return float(value_string) + except ValueError: + pass + return value_string + + +def _parse_dim_value_pairs( + dim_value_string: str, +) -> dict[str, DimValueType]: + """Parse a chunks string into a dict.""" + pairs = {} + if dim_value_string: + for entry in dim_value_string.split(','): + key, value = entry.split('=') + pairs[key] = _get_dim_value(value) + return pairs + + +def DEFINE_dim_value_pairs( # pylint: disable=invalid-name + name: str, + default: str, + help: str, # pylint: disable=redefined-builtin + **kwargs: Any, +): + """Flag for defining key=value pairs, string key, value a str/int/float.""" + parser = _DimValuePairParser() + serializer = _DimValuePairSerializer() return flags.DEFINE( parser, name, default, help, serializer=serializer, **kwargs )