Skip to content

Commit

Permalink
BUGFIX: Remove hard-coded time dimension in resample_in_time.py. Pr…
Browse files Browse the repository at this point in the history
…eviously, we used `TIME_DIM.value` in some places, and `time` in others.

PiperOrigin-RevId: 628121008
  • Loading branch information
langmore authored and Weatherbench2 authors committed Apr 25, 2024
1 parent c6c4a1a commit 1d3cb4d
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 11 deletions.
10 changes: 5 additions & 5 deletions scripts/resample_in_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def resample_in_time_chunk(


def resample_in_time_core(
chunk: xr.Dataset,
chunk: t.Union[xr.Dataset, xr.DataArray],
method: str,
period: pd.Timedelta,
statistic: str,
Expand Down Expand Up @@ -272,13 +272,13 @@ def main(argv: abc.Sequence[str]) -> None:
if METHOD.value == 'resample':
rsmp_times = resample_in_time_core(
# All stats will give the same times, so use 'mean' arbitrarily.
ds.time,
ds[TIME_DIM.value],
METHOD.value,
period,
statistic='mean',
).time
)[TIME_DIM.value]
else:
rsmp_times = ds.time
rsmp_times = ds[TIME_DIM.value]
assert isinstance(ds, xr.Dataset) # To satisfy pytype.
rsmp_template = (
xbeam.make_template(ds)
Expand Down Expand Up @@ -306,7 +306,7 @@ def main(argv: abc.Sequence[str]) -> None:
working_chunks.update(WORKING_CHUNKS.value)
if TIME_DIM.value in working_chunks:
raise ValueError('cannot include time working chunks')
working_chunks[TIME_DIM.value] = len(ds.time)
working_chunks[TIME_DIM.value] = len(ds[TIME_DIM.value])
output_chunks = input_chunks.copy()
output_chunks[TIME_DIM.value] = min(
len(rsmp_times), output_chunks[TIME_DIM.value]
Expand Down
166 changes: 160 additions & 6 deletions scripts/resample_in_time_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_demonstrating_returned_times_for_resample(self):
testcase_name='Resample_NoSuffix_5d',
method='resample',
add_mean_suffix=False,
period='1d',
period='5d',
),
dict(
testcase_name='Resample_YesSuffix_1w',
Expand Down Expand Up @@ -124,7 +124,10 @@ def test_demonstrating_returned_times_for_resample(self):
period='1d',
),
)
def test_resample(self, method, add_mean_suffix, period):
def test_resample_time(self, method, add_mean_suffix, period):
# Make sure slice(start, stop, period) doesn't give you a singleton, since
# then, for this singleton, the resampled mean/min/max will all be equal,
# and the test will fail.
time_start = '2021-02-01'
time_stop = '2021-04-01'
mean_vars = ['temperature', 'geopotential']
Expand Down Expand Up @@ -159,9 +162,9 @@ def test_resample(self, method, add_mean_suffix, period):
output_path=output_path,
method=method,
period=period,
mean_vars='temperature,geopotential',
min_vars='temperature,geopotential',
max_vars='temperature',
mean_vars=','.join(mean_vars),
min_vars=','.join(min_vars),
max_vars=','.join(max_vars),
add_mean_suffix=str(add_mean_suffix),
time_start=time_start,
time_stop=time_stop,
Expand All @@ -182,7 +185,10 @@ def test_resample(self, method, add_mean_suffix, period):
expected_mean = (
input_ds.sel(time=slice(time_start, time_stop))
# input_ds timedelta is 1 day.
.rolling(time=pd.to_timedelta(period) // pd.to_timedelta('1d')).mean()
.rolling(
time=pd.to_timedelta(period)
// pd.to_timedelta(input_time_resolution)
).mean()
)
else:
raise ValueError(f'Unhandled {method=}')
Expand Down Expand Up @@ -215,6 +221,154 @@ def test_resample(self, method, add_mean_suffix, period):

self.assertCountEqual(expected_varnames, output_ds.data_vars)

@parameterized.named_parameters(
dict(
testcase_name='Resample_NoSuffix_5d',
method='resample',
add_mean_suffix=False,
period='5d',
),
dict(
testcase_name='Resample_YesSuffix_1w',
method='resample',
add_mean_suffix=True,
period='1w',
),
dict(
testcase_name='Resample_YesSuffix_1d',
method='resample',
add_mean_suffix=True,
period='1d',
),
dict(
testcase_name='Roll_YesSuffix_1w',
method='rolling',
add_mean_suffix=True,
period='1w',
),
dict(
testcase_name='Roll_NoSuffix_30d',
method='rolling',
add_mean_suffix=False,
period='30d',
),
dict(
testcase_name='Roll_YesSuffix_1d',
method='rolling',
add_mean_suffix=True,
period='1d',
),
)
def test_resample_prediction_timedelta(self, method, add_mean_suffix, period):
# Make sure slice(start, stop, period) doesn't give you a singleton, since
# then, for this singleton, the resampled mean/min/max will all be equal,
# and the test will fail.
timedelta_start = '0 day'
timedelta_stop = '9 days'
mean_vars = ['temperature', 'geopotential']
min_vars = ['temperature', 'geopotential']
max_vars = ['temperature']
input_time_resolution = '1d'

input_ds = utils.random_like(
schema.mock_forecast_data(
lead_start='0 day',
lead_stop='15 days',
lead_resolution='1 day',
variables_3d=['temperature', 'geopotential', 'should_drop'],
time_start='2021-01-01',
time_stop='2021-01-10',
spatial_resolution_in_degrees=30.0,
time_resolution=input_time_resolution,
)
)

# Make variables different so we test that variables are handled
# individually.
input_ds = input_ds.assign({'geopotential': input_ds.geopotential + 10})

input_path = self.create_tempdir('source').full_path
output_path = self.create_tempdir('destination').full_path

input_chunks = {
'time': 9,
'prediction_timedelta': 5,
'longitude': 6,
'latitude': 5,
'level': 3,
}
input_ds.chunk(input_chunks).to_zarr(input_path)

with flagsaver.as_parsed(
input_path=input_path,
output_path=output_path,
method=method,
period=period,
mean_vars=','.join(mean_vars),
min_vars=','.join(min_vars),
max_vars=','.join(max_vars),
add_mean_suffix=str(add_mean_suffix),
time_start=timedelta_start,
time_stop=timedelta_stop,
working_chunks='level=1',
time_dim='prediction_timedelta',
runner='DirectRunner',
):
resample_in_time.main([])

output_ds, output_chunks = xarray_beam.open_zarr(output_path)

if method == 'resample':
expected_mean = (
input_ds.sel(
prediction_timedelta=slice(timedelta_start, timedelta_stop)
)
.resample(prediction_timedelta=pd.to_timedelta(period))
.mean()
)
elif method == 'rolling':
expected_mean = (
input_ds.sel(
prediction_timedelta=slice(timedelta_start, timedelta_stop)
)
# input_ds timedelta is 1 day.
.rolling(
prediction_timedelta=pd.to_timedelta(period)
// pd.to_timedelta(input_time_resolution)
).mean()
)
else:
raise ValueError(f'Unhandled {method=}')

expected_chunks = input_chunks.copy()
if method == 'resample':
expected_chunks['prediction_timedelta'] = min(
len(expected_mean.prediction_timedelta),
expected_chunks['prediction_timedelta'],
)
self.assertEqual(expected_chunks, output_chunks)

expected_varnames = []

for k in mean_vars:
expected_varnames.append(k + '_mean' if add_mean_suffix else k)
xr.testing.assert_allclose(
expected_mean[k],
output_ds[k + '_mean' if add_mean_suffix else k],
)

for k in min_vars:
expected_varnames.append(k + '_min')
if period != input_time_resolution:
np.testing.assert_array_less(output_ds[k + '_min'], expected_mean[k])

for k in max_vars:
expected_varnames.append(k + '_max')
if period != input_time_resolution:
np.testing.assert_array_less(expected_mean[k], output_ds[k + '_max'])

self.assertCountEqual(expected_varnames, output_ds.data_vars)


if __name__ == '__main__':
absltest.main()

0 comments on commit 1d3cb4d

Please sign in to comment.