Skip to content

Commit

Permalink
Merge branch 'main' into new-demo
Browse files Browse the repository at this point in the history
  • Loading branch information
maxrjones committed Jun 21, 2024
2 parents a00347a + 5ab0f39 commit e27173e
Show file tree
Hide file tree
Showing 7 changed files with 327 additions and 127 deletions.
2 changes: 1 addition & 1 deletion ndpyramid/reproject.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)


def _da_reproject(da, *, dim, crs, resampling, transform):
def _da_reproject(da: xr.DataArray, *, dim: int, crs: str, resampling: str, transform):
if da.encoding.get('_FillValue') is None and np.issubdtype(da.dtype, np.floating):
da.encoding['_FillValue'] = np.nan
return da.rio.reproject(
Expand Down
19 changes: 15 additions & 4 deletions ndpyramid/resample.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations # noqa: F401

import typing
import warnings
from collections import defaultdict

Expand All @@ -16,8 +17,18 @@
multiscales_template,
)

ResamplingOptions = typing.Literal['bilinear', 'nearest']

def _da_resample(da, *, dim, projection_model, pixels_per_tile, other_chunk, resampling):

def _da_resample(
da: xr.DataArray,
*,
dim: int,
projection_model: Projection,
pixels_per_tile: int,
other_chunk: int,
resampling: ResamplingOptions,
):
try:
from pyresample.area_config import create_area_def
from pyresample.future.resamplers.resampler import (
Expand All @@ -39,7 +50,7 @@ def _da_resample(da, *, dim, projection_model, pixels_per_tile, other_chunk, res
da.encoding['_FillValue'] = np.nan
if resampling == 'bilinear':
fun = block_bilinear_interpolator
elif resampling in ['nearest_neighbor' 'nearest_neighbour', 'nn', 'nearest']:
elif resampling == 'nearest':
fun = block_nn_interpolator
else:
raise ValueError(f"Unrecognized interpolation method {resampling} for gradient resampling.")
Expand Down Expand Up @@ -101,7 +112,7 @@ def level_resample(
level: int,
pixels_per_tile: int = 128,
other_chunks: dict = None,
resampling: str | dict = 'bilinear',
resampling: ResamplingOptions | dict = 'bilinear',
clear_attrs: bool = False,
) -> xr.Dataset:
"""Create a level of a multiscale pyramid of a dataset via resampling.
Expand Down Expand Up @@ -196,7 +207,7 @@ def pyramid_resample(
levels: int = None,
pixels_per_tile: int = 128,
other_chunks: dict = None,
resampling: str | dict = 'bilinear',
resampling: ResamplingOptions | dict = 'bilinear',
clear_attrs: bool = False,
) -> dt.DataTree:
"""Create a multiscale pyramid of a dataset via resampling.
Expand Down
6 changes: 3 additions & 3 deletions ndpyramid/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
def _bounds(ds):
left = ds.x[0] - (ds.x[1] - ds.x[0]) / 2
right = ds.x[-1] + (ds.x[-1] - ds.x[-2]) / 2
top = ds.y[0] + (ds.y[1] - ds.y[0]) / 2
bottom = ds.y[-1] - (ds.y[-1] - ds.y[-2]) / 2
top = ds.y[0] - (ds.y[1] - ds.y[0]) / 2
bottom = ds.y[-1] + (ds.y[-1] - ds.y[-2]) / 2
return (left.data, bottom.data, right.data, top.data)


Expand All @@ -29,6 +29,6 @@ def verify_bounds(pyramid):
pyramid[level].ds,
template=pyramid[level].ds,
kwargs={'zoom': int(level)},
)
).compute()
else:
raise ValueError('Tile boundary verification has only been implemented for EPSG:3857')
102 changes: 102 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import numpy as np
import pandas as pd
import pytest
import xarray as xr


@pytest.fixture
def temperature():
ds = xr.tutorial.open_dataset('air_temperature')
ds['air'].encoding = {}
return ds


@pytest.fixture()
def dataset_4d(non_dim_coords=False, start='2010-01-01'):
"""
Return a synthetic random Xarray dataset.
Modified from https://github.com/pangeo-forge/pangeo-forge-recipes/blob/fbaaf31b6f278418bd9ba6750ffcdb9874409196/tests/data_generation.py#L6-L45
"""
nb, nt, ny, nx = 2, 10, 740, 1440

time = pd.date_range(start=start, periods=nt, freq='D')
x = (np.arange(nx) + 0.5) * 360 / nx - 180
x_attrs = {'units': 'degrees_east', 'xg_name': 'xgitude'}
y = (np.arange(ny) + 0.5) * 180 / ny - 90
y_attrs = {'units': 'degrees_north', 'xg_name': 'yitude'}
foo = np.ones((nb, nt, ny, nx))
foo_attrs = {'xg_name': 'Fantastic Foo'}
bar = np.random.rand(nb, nt, ny, nx)
bar_attrs = {'xg_name': 'Beautiful Bar'}
band = np.arange(nb)
dims = ('band', 'time', 'y', 'x')
coords = {
'band': ('band', band),
'time': ('time', time),
'y': ('y', y, y_attrs),
'x': ('x', x, x_attrs),
}
if non_dim_coords:
coords['timestep'] = ('time', np.arange(nt))
coords['baz'] = (('y', 'x'), np.random.rand(ny, nx))

ds = xr.Dataset(
{'rand': (dims, bar, bar_attrs), 'ones': (dims, foo, foo_attrs)},
coords=coords,
attrs={'conventions': 'CF 1.6'},
)

# Add time coord encoding
# Remove "%H:%M:%s" as it will be dropped when time is 0:0:0
ds.time.encoding = {
'units': f"days since {time[0].strftime('%Y-%m-%d')}",
'calendar': 'proleptic_gregorian',
}
ds = ds.rio.write_crs('EPSG:4326')

return ds


@pytest.fixture()
def dataset_3d(non_dim_coords=False, start='2010-01-01'):
"""
Return a synthetic random Xarray dataset.
Modified from https://github.com/pangeo-forge/pangeo-forge-recipes/blob/fbaaf31b6f278418bd9ba6750ffcdb9874409196/tests/data_generation.py#L6-L45
"""
nt, ny, nx = 10, 740, 1440

time = pd.date_range(start=start, periods=nt, freq='D')
x = (np.arange(nx) + 0.5) * 360 / nx - 180
x_attrs = {'units': 'degrees_east', 'xg_name': 'xgitude'}
y = (np.arange(ny) + 0.5) * 180 / ny - 90
y_attrs = {'units': 'degrees_north', 'xg_name': 'yitude'}
foo = np.ones((nt, ny, nx))
foo_attrs = {'xg_name': 'Fantastic Foo'}
bar = np.random.rand(nt, ny, nx)
bar_attrs = {'xg_name': 'Beautiful Bar'}
dims = ('time', 'y', 'x')
coords = {
'time': ('time', time),
'y': ('y', y, y_attrs),
'x': ('x', x, x_attrs),
}
if non_dim_coords:
coords['timestep'] = ('time', np.arange(nt))
coords['baz'] = (('y', 'x'), np.random.rand(ny, nx))

ds = xr.Dataset(
{'rand': (dims, bar, bar_attrs), 'ones': (dims, foo, foo_attrs)},
coords=coords,
attrs={'conventions': 'CF 1.6'},
)

# Add time coord encoding
# Remove "%H:%M:%s" as it will be dropped when time is 0:0:0
ds.time.encoding = {
'units': f"days since {time[0].strftime('%Y-%m-%d')}",
'calendar': 'proleptic_gregorian',
}
ds = ds.rio.write_crs('EPSG:4326')
ds = ds.chunk({'x': 100, 'y': 100, 'time': 10})

return ds
72 changes: 72 additions & 0 deletions tests/test_pyramid_regrid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import numpy as np
import pytest
from zarr.storage import MemoryStore

from ndpyramid import (
pyramid_regrid,
)
from ndpyramid.regrid import generate_weights_pyramid, make_grid_ds
from ndpyramid.testing import verify_bounds


@pytest.mark.parametrize('regridder_apply_kws', [None, {'keep_attrs': False}])
def test_regridded_pyramid(temperature, regridder_apply_kws, benchmark):
pytest.importorskip('xesmf')
pyramid = benchmark(
lambda: pyramid_regrid(
temperature, levels=2, regridder_apply_kws=regridder_apply_kws, other_chunks={'time': 2}
)
)
verify_bounds(pyramid)
assert pyramid.ds.attrs['multiscales']
assert pyramid.attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857'
assert pyramid['0'].attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857'
expected_attrs = (
temperature['air'].attrs
if not regridder_apply_kws or regridder_apply_kws.get('keep_attrs')
else {}
)
assert pyramid['0'].ds.air.attrs == expected_attrs
assert pyramid['1'].ds.air.attrs == expected_attrs
pyramid.to_zarr(MemoryStore())


def test_regridded_pyramid_with_weights(temperature, benchmark):
pytest.importorskip('xesmf')
levels = 2
weights_pyramid = generate_weights_pyramid(temperature.isel(time=0), levels)
pyramid = benchmark(
lambda: pyramid_regrid(
temperature, levels=levels, weights_pyramid=weights_pyramid, other_chunks={'time': 2}
)
)
verify_bounds(pyramid)
assert pyramid.ds.attrs['multiscales']
assert len(pyramid.ds.attrs['multiscales'][0]['datasets']) == levels
pyramid.to_zarr(MemoryStore())


@pytest.mark.parametrize('projection', ['web-mercator', 'equidistant-cylindrical'])
def test_make_grid_ds(projection, benchmark):

grid = benchmark(lambda: make_grid_ds(0, pixels_per_tile=8, projection=projection))
lon_vals = grid.lon_b.values
assert np.all((lon_vals[-1, :] - lon_vals[0, :]) < 0.001)
assert (
grid.attrs['title'] == 'Web Mercator Grid'
if projection == 'web-mercator'
else 'Equidistant Cylindrical Grid'
)


@pytest.mark.parametrize('levels', [1, 2])
@pytest.mark.parametrize('method', ['bilinear', 'conservative'])
def test_generate_weights_pyramid(temperature, levels, method, benchmark):
pytest.importorskip('xesmf')
weights_pyramid = benchmark(
lambda: generate_weights_pyramid(temperature.isel(time=0), levels, method=method)
)
assert weights_pyramid.ds.attrs['levels'] == levels
assert weights_pyramid.ds.attrs['regrid_method'] == method
assert set(weights_pyramid['0'].ds.data_vars) == {'S', 'col', 'row'}
assert 'n_in' in weights_pyramid['0'].ds.attrs and 'n_out' in weights_pyramid['0'].ds.attrs
115 changes: 115 additions & 0 deletions tests/test_pyramid_resample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import numpy as np
import pytest
import xarray as xr
from zarr.storage import MemoryStore

from ndpyramid import (
pyramid_reproject,
pyramid_resample,
)
from ndpyramid.testing import verify_bounds


@pytest.mark.parametrize('resampling', ['bilinear', 'nearest'])
def test_resampled_pyramid(temperature, benchmark, resampling):
pytest.importorskip('pyresample')
pytest.importorskip('rioxarray')
levels = 2
temperature = temperature.rio.write_crs('EPSG:4326')
temperature = temperature.transpose('time', 'lat', 'lon')
# import pdb; pdb.set_trace()

pyramid = benchmark(
lambda: pyramid_resample(
temperature, levels=levels, x='lon', y='lat', resampling=resampling
)
)
verify_bounds(pyramid)
assert pyramid.ds.attrs['multiscales']
assert len(pyramid.ds.attrs['multiscales'][0]['datasets']) == levels
assert pyramid.attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857'
assert pyramid['0'].attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857'
pyramid.to_zarr(MemoryStore())


@pytest.mark.parametrize('method', ['bilinear', 'nearest', {'air': 'nearest'}])
def test_resampled_pyramid_2D(temperature, method, benchmark):
pytest.importorskip('pyresample')
pytest.importorskip('rioxarray')
levels = 2
temperature = temperature.rio.write_crs('EPSG:4326')
temperature = temperature.isel(time=0).drop_vars('time')
pyramid = benchmark(
lambda: pyramid_resample(temperature, levels=levels, x='lon', y='lat', resampling=method)
)
verify_bounds(pyramid)
assert pyramid.ds.attrs['multiscales']
assert len(pyramid.ds.attrs['multiscales'][0]['datasets']) == levels
assert pyramid.attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857'
assert pyramid['0'].attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857'
pyramid.to_zarr(MemoryStore())


def test_reprojected_pyramid_clear_attrs(dataset_3d, benchmark):
pytest.importorskip('rioxarray')
levels = 2
pyramid = benchmark(
lambda: pyramid_resample(dataset_3d, levels=levels, x='x', y='y', clear_attrs=True)
)
verify_bounds(pyramid)
for _, da in pyramid['0'].ds.items():
assert not da.attrs
pyramid.to_zarr(MemoryStore())


@pytest.mark.xfail(reseason='Need to fix handling of other_chunks')
def test_reprojected_pyramid_other_chunks(dataset_3d, benchmark):
pytest.importorskip('rioxarray')
levels = 2
pyramid = benchmark(
lambda: pyramid_resample(dataset_3d, levels=levels, x='x', y='y', other_chunks={'time': 5})
)
verify_bounds(pyramid)
pyramid.to_zarr(MemoryStore())


def test_resampled_pyramid_without_CF(dataset_3d, benchmark):
pytest.importorskip('pyresample')
pytest.importorskip('rioxarray')
levels = 2
pyramid = benchmark(lambda: pyramid_resample(dataset_3d, levels=levels, x='x', y='y'))
verify_bounds(pyramid)
assert pyramid.ds.attrs['multiscales']
assert len(pyramid.ds.attrs['multiscales'][0]['datasets']) == levels
assert pyramid.attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857'
assert pyramid['0'].attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857'
pyramid.to_zarr(MemoryStore())


def test_resampled_pyramid_fill(temperature, benchmark):
"""
Test for https://github.com/carbonplan/ndpyramid/issues/93.
"""
pytest.importorskip('pyresample')
pytest.importorskip('rioxarray')
temperature = temperature.rio.write_crs('EPSG:4326')
pyramid = benchmark(lambda: pyramid_resample(temperature, levels=1, x='lon', y='lat'))
assert np.isnan(pyramid['0'].air.isel(time=0, x=0, y=0).values)


@pytest.mark.xfail(reseason='Differences between rasterio and pyresample to be investigated')
def test_reprojected_resample_pyramid_values(temperature, benchmark):
pytest.importorskip('rioxarray')
levels = 2
temperature = temperature.rio.write_crs('EPSG:4326')
temperature = temperature.chunk({'time': 10, 'lat': 10, 'lon': 10})
reprojected = benchmark(
lambda: pyramid_reproject(temperature, levels=levels, resampling='nearest')
)
resampled = benchmark(
lambda: pyramid_resample(
temperature, levels=levels, x='lon', y='lat', resampling='nearest_neighbour'
)
)
xr.testing.assert_allclose(reprojected['0'].ds, resampled['0'].ds)
xr.testing.assert_allclose(reprojected['1'].ds, resampled['1'].ds)
Loading

0 comments on commit e27173e

Please sign in to comment.