diff --git a/ndpyramid/reproject.py b/ndpyramid/reproject.py index ba586a7..bda70d4 100644 --- a/ndpyramid/reproject.py +++ b/ndpyramid/reproject.py @@ -16,7 +16,7 @@ ) -def _da_reproject(da, *, dim, crs, resampling, transform): +def _da_reproject(da: xr.DataArray, *, dim: int, crs: str, resampling: str, transform): if da.encoding.get('_FillValue') is None and np.issubdtype(da.dtype, np.floating): da.encoding['_FillValue'] = np.nan return da.rio.reproject( diff --git a/ndpyramid/resample.py b/ndpyramid/resample.py index 36fd93f..b7806d0 100644 --- a/ndpyramid/resample.py +++ b/ndpyramid/resample.py @@ -1,5 +1,6 @@ from __future__ import annotations # noqa: F401 +import typing import warnings from collections import defaultdict @@ -16,8 +17,18 @@ multiscales_template, ) +ResamplingOptions = typing.Literal['bilinear', 'nearest'] -def _da_resample(da, *, dim, projection_model, pixels_per_tile, other_chunk, resampling): + +def _da_resample( + da: xr.DataArray, + *, + dim: int, + projection_model: Projection, + pixels_per_tile: int, + other_chunk: int, + resampling: ResamplingOptions, +): try: from pyresample.area_config import create_area_def from pyresample.future.resamplers.resampler import ( @@ -39,7 +50,7 @@ def _da_resample(da, *, dim, projection_model, pixels_per_tile, other_chunk, res da.encoding['_FillValue'] = np.nan if resampling == 'bilinear': fun = block_bilinear_interpolator - elif resampling in ['nearest_neighbor' 'nearest_neighbour', 'nn', 'nearest']: + elif resampling == 'nearest': fun = block_nn_interpolator else: raise ValueError(f"Unrecognized interpolation method {resampling} for gradient resampling.") @@ -101,7 +112,7 @@ def level_resample( level: int, pixels_per_tile: int = 128, other_chunks: dict = None, - resampling: str | dict = 'bilinear', + resampling: ResamplingOptions | dict = 'bilinear', clear_attrs: bool = False, ) -> xr.Dataset: """Create a level of a multiscale pyramid of a dataset via resampling. @@ -196,7 +207,7 @@ def pyramid_resample( levels: int = None, pixels_per_tile: int = 128, other_chunks: dict = None, - resampling: str | dict = 'bilinear', + resampling: ResamplingOptions | dict = 'bilinear', clear_attrs: bool = False, ) -> dt.DataTree: """Create a multiscale pyramid of a dataset via resampling. diff --git a/ndpyramid/testing.py b/ndpyramid/testing.py index 740ba32..a717118 100644 --- a/ndpyramid/testing.py +++ b/ndpyramid/testing.py @@ -6,8 +6,8 @@ def _bounds(ds): left = ds.x[0] - (ds.x[1] - ds.x[0]) / 2 right = ds.x[-1] + (ds.x[-1] - ds.x[-2]) / 2 - top = ds.y[0] + (ds.y[1] - ds.y[0]) / 2 - bottom = ds.y[-1] - (ds.y[-1] - ds.y[-2]) / 2 + top = ds.y[0] - (ds.y[1] - ds.y[0]) / 2 + bottom = ds.y[-1] + (ds.y[-1] - ds.y[-2]) / 2 return (left.data, bottom.data, right.data, top.data) @@ -29,6 +29,6 @@ def verify_bounds(pyramid): pyramid[level].ds, template=pyramid[level].ds, kwargs={'zoom': int(level)}, - ) + ).compute() else: raise ValueError('Tile boundary verification has only been implemented for EPSG:3857') diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..324f50a --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,102 @@ +import numpy as np +import pandas as pd +import pytest +import xarray as xr + + +@pytest.fixture +def temperature(): + ds = xr.tutorial.open_dataset('air_temperature') + ds['air'].encoding = {} + return ds + + +@pytest.fixture() +def dataset_4d(non_dim_coords=False, start='2010-01-01'): + """ + Return a synthetic random Xarray dataset. + Modified from https://github.com/pangeo-forge/pangeo-forge-recipes/blob/fbaaf31b6f278418bd9ba6750ffcdb9874409196/tests/data_generation.py#L6-L45 + """ + nb, nt, ny, nx = 2, 10, 740, 1440 + + time = pd.date_range(start=start, periods=nt, freq='D') + x = (np.arange(nx) + 0.5) * 360 / nx - 180 + x_attrs = {'units': 'degrees_east', 'xg_name': 'xgitude'} + y = (np.arange(ny) + 0.5) * 180 / ny - 90 + y_attrs = {'units': 'degrees_north', 'xg_name': 'yitude'} + foo = np.ones((nb, nt, ny, nx)) + foo_attrs = {'xg_name': 'Fantastic Foo'} + bar = np.random.rand(nb, nt, ny, nx) + bar_attrs = {'xg_name': 'Beautiful Bar'} + band = np.arange(nb) + dims = ('band', 'time', 'y', 'x') + coords = { + 'band': ('band', band), + 'time': ('time', time), + 'y': ('y', y, y_attrs), + 'x': ('x', x, x_attrs), + } + if non_dim_coords: + coords['timestep'] = ('time', np.arange(nt)) + coords['baz'] = (('y', 'x'), np.random.rand(ny, nx)) + + ds = xr.Dataset( + {'rand': (dims, bar, bar_attrs), 'ones': (dims, foo, foo_attrs)}, + coords=coords, + attrs={'conventions': 'CF 1.6'}, + ) + + # Add time coord encoding + # Remove "%H:%M:%s" as it will be dropped when time is 0:0:0 + ds.time.encoding = { + 'units': f"days since {time[0].strftime('%Y-%m-%d')}", + 'calendar': 'proleptic_gregorian', + } + ds = ds.rio.write_crs('EPSG:4326') + + return ds + + +@pytest.fixture() +def dataset_3d(non_dim_coords=False, start='2010-01-01'): + """ + Return a synthetic random Xarray dataset. + Modified from https://github.com/pangeo-forge/pangeo-forge-recipes/blob/fbaaf31b6f278418bd9ba6750ffcdb9874409196/tests/data_generation.py#L6-L45 + """ + nt, ny, nx = 10, 740, 1440 + + time = pd.date_range(start=start, periods=nt, freq='D') + x = (np.arange(nx) + 0.5) * 360 / nx - 180 + x_attrs = {'units': 'degrees_east', 'xg_name': 'xgitude'} + y = (np.arange(ny) + 0.5) * 180 / ny - 90 + y_attrs = {'units': 'degrees_north', 'xg_name': 'yitude'} + foo = np.ones((nt, ny, nx)) + foo_attrs = {'xg_name': 'Fantastic Foo'} + bar = np.random.rand(nt, ny, nx) + bar_attrs = {'xg_name': 'Beautiful Bar'} + dims = ('time', 'y', 'x') + coords = { + 'time': ('time', time), + 'y': ('y', y, y_attrs), + 'x': ('x', x, x_attrs), + } + if non_dim_coords: + coords['timestep'] = ('time', np.arange(nt)) + coords['baz'] = (('y', 'x'), np.random.rand(ny, nx)) + + ds = xr.Dataset( + {'rand': (dims, bar, bar_attrs), 'ones': (dims, foo, foo_attrs)}, + coords=coords, + attrs={'conventions': 'CF 1.6'}, + ) + + # Add time coord encoding + # Remove "%H:%M:%s" as it will be dropped when time is 0:0:0 + ds.time.encoding = { + 'units': f"days since {time[0].strftime('%Y-%m-%d')}", + 'calendar': 'proleptic_gregorian', + } + ds = ds.rio.write_crs('EPSG:4326') + ds = ds.chunk({'x': 100, 'y': 100, 'time': 10}) + + return ds diff --git a/tests/test_pyramid_regrid.py b/tests/test_pyramid_regrid.py new file mode 100644 index 0000000..e63835f --- /dev/null +++ b/tests/test_pyramid_regrid.py @@ -0,0 +1,72 @@ +import numpy as np +import pytest +from zarr.storage import MemoryStore + +from ndpyramid import ( + pyramid_regrid, +) +from ndpyramid.regrid import generate_weights_pyramid, make_grid_ds +from ndpyramid.testing import verify_bounds + + +@pytest.mark.parametrize('regridder_apply_kws', [None, {'keep_attrs': False}]) +def test_regridded_pyramid(temperature, regridder_apply_kws, benchmark): + pytest.importorskip('xesmf') + pyramid = benchmark( + lambda: pyramid_regrid( + temperature, levels=2, regridder_apply_kws=regridder_apply_kws, other_chunks={'time': 2} + ) + ) + verify_bounds(pyramid) + assert pyramid.ds.attrs['multiscales'] + assert pyramid.attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857' + assert pyramid['0'].attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857' + expected_attrs = ( + temperature['air'].attrs + if not regridder_apply_kws or regridder_apply_kws.get('keep_attrs') + else {} + ) + assert pyramid['0'].ds.air.attrs == expected_attrs + assert pyramid['1'].ds.air.attrs == expected_attrs + pyramid.to_zarr(MemoryStore()) + + +def test_regridded_pyramid_with_weights(temperature, benchmark): + pytest.importorskip('xesmf') + levels = 2 + weights_pyramid = generate_weights_pyramid(temperature.isel(time=0), levels) + pyramid = benchmark( + lambda: pyramid_regrid( + temperature, levels=levels, weights_pyramid=weights_pyramid, other_chunks={'time': 2} + ) + ) + verify_bounds(pyramid) + assert pyramid.ds.attrs['multiscales'] + assert len(pyramid.ds.attrs['multiscales'][0]['datasets']) == levels + pyramid.to_zarr(MemoryStore()) + + +@pytest.mark.parametrize('projection', ['web-mercator', 'equidistant-cylindrical']) +def test_make_grid_ds(projection, benchmark): + + grid = benchmark(lambda: make_grid_ds(0, pixels_per_tile=8, projection=projection)) + lon_vals = grid.lon_b.values + assert np.all((lon_vals[-1, :] - lon_vals[0, :]) < 0.001) + assert ( + grid.attrs['title'] == 'Web Mercator Grid' + if projection == 'web-mercator' + else 'Equidistant Cylindrical Grid' + ) + + +@pytest.mark.parametrize('levels', [1, 2]) +@pytest.mark.parametrize('method', ['bilinear', 'conservative']) +def test_generate_weights_pyramid(temperature, levels, method, benchmark): + pytest.importorskip('xesmf') + weights_pyramid = benchmark( + lambda: generate_weights_pyramid(temperature.isel(time=0), levels, method=method) + ) + assert weights_pyramid.ds.attrs['levels'] == levels + assert weights_pyramid.ds.attrs['regrid_method'] == method + assert set(weights_pyramid['0'].ds.data_vars) == {'S', 'col', 'row'} + assert 'n_in' in weights_pyramid['0'].ds.attrs and 'n_out' in weights_pyramid['0'].ds.attrs diff --git a/tests/test_pyramid_resample.py b/tests/test_pyramid_resample.py new file mode 100644 index 0000000..28b497f --- /dev/null +++ b/tests/test_pyramid_resample.py @@ -0,0 +1,115 @@ +import numpy as np +import pytest +import xarray as xr +from zarr.storage import MemoryStore + +from ndpyramid import ( + pyramid_reproject, + pyramid_resample, +) +from ndpyramid.testing import verify_bounds + + +@pytest.mark.parametrize('resampling', ['bilinear', 'nearest']) +def test_resampled_pyramid(temperature, benchmark, resampling): + pytest.importorskip('pyresample') + pytest.importorskip('rioxarray') + levels = 2 + temperature = temperature.rio.write_crs('EPSG:4326') + temperature = temperature.transpose('time', 'lat', 'lon') + # import pdb; pdb.set_trace() + + pyramid = benchmark( + lambda: pyramid_resample( + temperature, levels=levels, x='lon', y='lat', resampling=resampling + ) + ) + verify_bounds(pyramid) + assert pyramid.ds.attrs['multiscales'] + assert len(pyramid.ds.attrs['multiscales'][0]['datasets']) == levels + assert pyramid.attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857' + assert pyramid['0'].attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857' + pyramid.to_zarr(MemoryStore()) + + +@pytest.mark.parametrize('method', ['bilinear', 'nearest', {'air': 'nearest'}]) +def test_resampled_pyramid_2D(temperature, method, benchmark): + pytest.importorskip('pyresample') + pytest.importorskip('rioxarray') + levels = 2 + temperature = temperature.rio.write_crs('EPSG:4326') + temperature = temperature.isel(time=0).drop_vars('time') + pyramid = benchmark( + lambda: pyramid_resample(temperature, levels=levels, x='lon', y='lat', resampling=method) + ) + verify_bounds(pyramid) + assert pyramid.ds.attrs['multiscales'] + assert len(pyramid.ds.attrs['multiscales'][0]['datasets']) == levels + assert pyramid.attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857' + assert pyramid['0'].attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857' + pyramid.to_zarr(MemoryStore()) + + +def test_reprojected_pyramid_clear_attrs(dataset_3d, benchmark): + pytest.importorskip('rioxarray') + levels = 2 + pyramid = benchmark( + lambda: pyramid_resample(dataset_3d, levels=levels, x='x', y='y', clear_attrs=True) + ) + verify_bounds(pyramid) + for _, da in pyramid['0'].ds.items(): + assert not da.attrs + pyramid.to_zarr(MemoryStore()) + + +@pytest.mark.xfail(reseason='Need to fix handling of other_chunks') +def test_reprojected_pyramid_other_chunks(dataset_3d, benchmark): + pytest.importorskip('rioxarray') + levels = 2 + pyramid = benchmark( + lambda: pyramid_resample(dataset_3d, levels=levels, x='x', y='y', other_chunks={'time': 5}) + ) + verify_bounds(pyramid) + pyramid.to_zarr(MemoryStore()) + + +def test_resampled_pyramid_without_CF(dataset_3d, benchmark): + pytest.importorskip('pyresample') + pytest.importorskip('rioxarray') + levels = 2 + pyramid = benchmark(lambda: pyramid_resample(dataset_3d, levels=levels, x='x', y='y')) + verify_bounds(pyramid) + assert pyramid.ds.attrs['multiscales'] + assert len(pyramid.ds.attrs['multiscales'][0]['datasets']) == levels + assert pyramid.attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857' + assert pyramid['0'].attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857' + pyramid.to_zarr(MemoryStore()) + + +def test_resampled_pyramid_fill(temperature, benchmark): + """ + Test for https://github.com/carbonplan/ndpyramid/issues/93. + """ + pytest.importorskip('pyresample') + pytest.importorskip('rioxarray') + temperature = temperature.rio.write_crs('EPSG:4326') + pyramid = benchmark(lambda: pyramid_resample(temperature, levels=1, x='lon', y='lat')) + assert np.isnan(pyramid['0'].air.isel(time=0, x=0, y=0).values) + + +@pytest.mark.xfail(reseason='Differences between rasterio and pyresample to be investigated') +def test_reprojected_resample_pyramid_values(temperature, benchmark): + pytest.importorskip('rioxarray') + levels = 2 + temperature = temperature.rio.write_crs('EPSG:4326') + temperature = temperature.chunk({'time': 10, 'lat': 10, 'lon': 10}) + reprojected = benchmark( + lambda: pyramid_reproject(temperature, levels=levels, resampling='nearest') + ) + resampled = benchmark( + lambda: pyramid_resample( + temperature, levels=levels, x='lon', y='lat', resampling='nearest_neighbour' + ) + ) + xr.testing.assert_allclose(reprojected['0'].ds, resampled['0'].ds) + xr.testing.assert_allclose(reprojected['1'].ds, resampled['1'].ds) diff --git a/tests/test_pyramids.py b/tests/test_pyramids.py index 58efc65..5c20ea2 100644 --- a/tests/test_pyramids.py +++ b/tests/test_pyramids.py @@ -1,26 +1,15 @@ import numpy as np import pytest -import xarray as xr from zarr.storage import MemoryStore from ndpyramid import ( pyramid_coarsen, pyramid_create, - pyramid_regrid, pyramid_reproject, - pyramid_resample, ) -from ndpyramid.regrid import generate_weights_pyramid, make_grid_ds from ndpyramid.testing import verify_bounds -@pytest.fixture -def temperature(): - ds = xr.tutorial.open_dataset('air_temperature') - ds['air'].encoding = {} - return ds - - def test_xarray_coarsened_pyramid(temperature, benchmark): factors = [4, 2, 1] pyramid = benchmark( @@ -70,63 +59,38 @@ def test_reprojected_pyramid(temperature, benchmark): pyramid.to_zarr(MemoryStore()) -def test_reprojected_pyramid_fill(temperature, benchmark): - """ - Test for https://github.com/carbonplan/ndpyramid/issues/93. - """ - pytest.importorskip('rioxarray') - temperature = temperature.rio.write_crs('EPSG:4326') - pyramid = benchmark(lambda: pyramid_reproject(temperature, levels=1)) - assert np.isnan(pyramid['0'].air.isel(time=0, x=0, y=0).values) - - -@pytest.mark.parametrize('resampling', ['bilinear', 'nearest']) -def test_resampled_pyramid(temperature, benchmark, resampling): - pytest.importorskip('pyresample') +def test_reprojected_pyramid_resampling_dict(dataset_3d, benchmark): pytest.importorskip('rioxarray') levels = 2 - temperature = temperature.rio.write_crs('EPSG:4326') - temperature = temperature.transpose('time', 'lat', 'lon') - # import pdb; pdb.set_trace() - pyramid = benchmark( - lambda: pyramid_resample( - temperature, levels=levels, x='lon', y='lat', resampling=resampling + lambda: pyramid_reproject( + dataset_3d, levels=levels, resampling={'ones': 'bilinear', 'rand': 'nearest'} ) ) verify_bounds(pyramid) - assert pyramid.ds.attrs['multiscales'] - assert len(pyramid.ds.attrs['multiscales'][0]['datasets']) == levels - assert pyramid.attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857' - assert pyramid['0'].attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857' + assert pyramid.attrs['multiscales'][0]['metadata']['kwargs']['resampling'] == { + 'ones': 'bilinear', + 'rand': 'nearest', + } pyramid.to_zarr(MemoryStore()) -@pytest.mark.xfail(reseason='Differences between rasterio and pyresample to be investigated') -def test_reprojected_resample_pyramid_values(temperature, benchmark): +def test_reprojected_pyramid_clear_attrs(dataset_3d, benchmark): pytest.importorskip('rioxarray') levels = 2 - temperature = temperature.rio.write_crs('EPSG:4326') - temperature = temperature.chunk({'time': 10, 'lat': 10, 'lon': 10}) - reprojected = benchmark( - lambda: pyramid_reproject(temperature, levels=levels, resampling='nearest') - ) - resampled = benchmark( - lambda: pyramid_resample( - temperature, levels=levels, x='lon', y='lat', resampling='nearest_neighbour' - ) - ) - xr.testing.assert_allclose(reprojected['0'].ds, resampled['0'].ds) - xr.testing.assert_allclose(reprojected['1'].ds, resampled['1'].ds) + pyramid = benchmark(lambda: pyramid_reproject(dataset_3d, levels=levels, clear_attrs=True)) + verify_bounds(pyramid) + for _, da in pyramid['0'].ds.items(): + assert not da.attrs + pyramid.to_zarr(MemoryStore()) -def test_resampled_pyramid_2D(temperature, benchmark): - pytest.importorskip('pyresample') +def test_reprojected_pyramid_4d(dataset_4d, benchmark): pytest.importorskip('rioxarray') levels = 2 - temperature = temperature.rio.write_crs('EPSG:4326') - temperature = temperature.isel(time=0).drop_vars('time') - pyramid = benchmark(lambda: pyramid_resample(temperature, levels=levels, x='lon', y='lat')) + with pytest.raises(Exception): + pyramid = benchmark(lambda: pyramid_reproject(dataset_4d, levels=levels)) + pyramid = benchmark(lambda: pyramid_reproject(dataset_4d, levels=levels, extra_dim='band')) verify_bounds(pyramid) assert pyramid.ds.attrs['multiscales'] assert len(pyramid.ds.attrs['multiscales'][0]['datasets']) == levels @@ -135,75 +99,11 @@ def test_resampled_pyramid_2D(temperature, benchmark): pyramid.to_zarr(MemoryStore()) -def test_resampled_pyramid_fill(temperature, benchmark): +def test_reprojected_pyramid_fill(temperature, benchmark): """ Test for https://github.com/carbonplan/ndpyramid/issues/93. """ - pytest.importorskip('pyresample') pytest.importorskip('rioxarray') temperature = temperature.rio.write_crs('EPSG:4326') - pyramid = benchmark(lambda: pyramid_resample(temperature, levels=1, x='lon', y='lat')) + pyramid = benchmark(lambda: pyramid_reproject(temperature, levels=1)) assert np.isnan(pyramid['0'].air.isel(time=0, x=0, y=0).values) - - -@pytest.mark.parametrize('regridder_apply_kws', [None, {'keep_attrs': False}]) -def test_regridded_pyramid(temperature, regridder_apply_kws, benchmark): - pytest.importorskip('xesmf') - pyramid = benchmark( - lambda: pyramid_regrid( - temperature, levels=2, regridder_apply_kws=regridder_apply_kws, other_chunks={'time': 2} - ) - ) - verify_bounds(pyramid) - assert pyramid.ds.attrs['multiscales'] - assert pyramid.attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857' - assert pyramid['0'].attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857' - expected_attrs = ( - temperature['air'].attrs - if not regridder_apply_kws or regridder_apply_kws.get('keep_attrs') - else {} - ) - assert pyramid['0'].ds.air.attrs == expected_attrs - assert pyramid['1'].ds.air.attrs == expected_attrs - pyramid.to_zarr(MemoryStore()) - - -def test_regridded_pyramid_with_weights(temperature, benchmark): - pytest.importorskip('xesmf') - levels = 2 - weights_pyramid = generate_weights_pyramid(temperature.isel(time=0), levels) - pyramid = benchmark( - lambda: pyramid_regrid( - temperature, levels=levels, weights_pyramid=weights_pyramid, other_chunks={'time': 2} - ) - ) - verify_bounds(pyramid) - assert pyramid.ds.attrs['multiscales'] - assert len(pyramid.ds.attrs['multiscales'][0]['datasets']) == levels - pyramid.to_zarr(MemoryStore()) - - -@pytest.mark.parametrize('projection', ['web-mercator', 'equidistant-cylindrical']) -def test_make_grid_ds(projection, benchmark): - - grid = benchmark(lambda: make_grid_ds(0, pixels_per_tile=8, projection=projection)) - lon_vals = grid.lon_b.values - assert np.all((lon_vals[-1, :] - lon_vals[0, :]) < 0.001) - assert ( - grid.attrs['title'] == 'Web Mercator Grid' - if projection == 'web-mercator' - else 'Equidistant Cylindrical Grid' - ) - - -@pytest.mark.parametrize('levels', [1, 2]) -@pytest.mark.parametrize('method', ['bilinear', 'conservative']) -def test_generate_weights_pyramid(temperature, levels, method, benchmark): - pytest.importorskip('xesmf') - weights_pyramid = benchmark( - lambda: generate_weights_pyramid(temperature.isel(time=0), levels, method=method) - ) - assert weights_pyramid.ds.attrs['levels'] == levels - assert weights_pyramid.ds.attrs['regrid_method'] == method - assert set(weights_pyramid['0'].ds.data_vars) == {'S', 'col', 'row'} - assert 'n_in' in weights_pyramid['0'].ds.attrs and 'n_out' in weights_pyramid['0'].ds.attrs