generated from carbonplan/python-project-template
-
-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
327 additions
and
127 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import numpy as np | ||
import pandas as pd | ||
import pytest | ||
import xarray as xr | ||
|
||
|
||
@pytest.fixture | ||
def temperature(): | ||
ds = xr.tutorial.open_dataset('air_temperature') | ||
ds['air'].encoding = {} | ||
return ds | ||
|
||
|
||
@pytest.fixture() | ||
def dataset_4d(non_dim_coords=False, start='2010-01-01'): | ||
""" | ||
Return a synthetic random Xarray dataset. | ||
Modified from https://github.com/pangeo-forge/pangeo-forge-recipes/blob/fbaaf31b6f278418bd9ba6750ffcdb9874409196/tests/data_generation.py#L6-L45 | ||
""" | ||
nb, nt, ny, nx = 2, 10, 740, 1440 | ||
|
||
time = pd.date_range(start=start, periods=nt, freq='D') | ||
x = (np.arange(nx) + 0.5) * 360 / nx - 180 | ||
x_attrs = {'units': 'degrees_east', 'xg_name': 'xgitude'} | ||
y = (np.arange(ny) + 0.5) * 180 / ny - 90 | ||
y_attrs = {'units': 'degrees_north', 'xg_name': 'yitude'} | ||
foo = np.ones((nb, nt, ny, nx)) | ||
foo_attrs = {'xg_name': 'Fantastic Foo'} | ||
bar = np.random.rand(nb, nt, ny, nx) | ||
bar_attrs = {'xg_name': 'Beautiful Bar'} | ||
band = np.arange(nb) | ||
dims = ('band', 'time', 'y', 'x') | ||
coords = { | ||
'band': ('band', band), | ||
'time': ('time', time), | ||
'y': ('y', y, y_attrs), | ||
'x': ('x', x, x_attrs), | ||
} | ||
if non_dim_coords: | ||
coords['timestep'] = ('time', np.arange(nt)) | ||
coords['baz'] = (('y', 'x'), np.random.rand(ny, nx)) | ||
|
||
ds = xr.Dataset( | ||
{'rand': (dims, bar, bar_attrs), 'ones': (dims, foo, foo_attrs)}, | ||
coords=coords, | ||
attrs={'conventions': 'CF 1.6'}, | ||
) | ||
|
||
# Add time coord encoding | ||
# Remove "%H:%M:%s" as it will be dropped when time is 0:0:0 | ||
ds.time.encoding = { | ||
'units': f"days since {time[0].strftime('%Y-%m-%d')}", | ||
'calendar': 'proleptic_gregorian', | ||
} | ||
ds = ds.rio.write_crs('EPSG:4326') | ||
|
||
return ds | ||
|
||
|
||
@pytest.fixture() | ||
def dataset_3d(non_dim_coords=False, start='2010-01-01'): | ||
""" | ||
Return a synthetic random Xarray dataset. | ||
Modified from https://github.com/pangeo-forge/pangeo-forge-recipes/blob/fbaaf31b6f278418bd9ba6750ffcdb9874409196/tests/data_generation.py#L6-L45 | ||
""" | ||
nt, ny, nx = 10, 740, 1440 | ||
|
||
time = pd.date_range(start=start, periods=nt, freq='D') | ||
x = (np.arange(nx) + 0.5) * 360 / nx - 180 | ||
x_attrs = {'units': 'degrees_east', 'xg_name': 'xgitude'} | ||
y = (np.arange(ny) + 0.5) * 180 / ny - 90 | ||
y_attrs = {'units': 'degrees_north', 'xg_name': 'yitude'} | ||
foo = np.ones((nt, ny, nx)) | ||
foo_attrs = {'xg_name': 'Fantastic Foo'} | ||
bar = np.random.rand(nt, ny, nx) | ||
bar_attrs = {'xg_name': 'Beautiful Bar'} | ||
dims = ('time', 'y', 'x') | ||
coords = { | ||
'time': ('time', time), | ||
'y': ('y', y, y_attrs), | ||
'x': ('x', x, x_attrs), | ||
} | ||
if non_dim_coords: | ||
coords['timestep'] = ('time', np.arange(nt)) | ||
coords['baz'] = (('y', 'x'), np.random.rand(ny, nx)) | ||
|
||
ds = xr.Dataset( | ||
{'rand': (dims, bar, bar_attrs), 'ones': (dims, foo, foo_attrs)}, | ||
coords=coords, | ||
attrs={'conventions': 'CF 1.6'}, | ||
) | ||
|
||
# Add time coord encoding | ||
# Remove "%H:%M:%s" as it will be dropped when time is 0:0:0 | ||
ds.time.encoding = { | ||
'units': f"days since {time[0].strftime('%Y-%m-%d')}", | ||
'calendar': 'proleptic_gregorian', | ||
} | ||
ds = ds.rio.write_crs('EPSG:4326') | ||
ds = ds.chunk({'x': 100, 'y': 100, 'time': 10}) | ||
|
||
return ds |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import numpy as np | ||
import pytest | ||
from zarr.storage import MemoryStore | ||
|
||
from ndpyramid import ( | ||
pyramid_regrid, | ||
) | ||
from ndpyramid.regrid import generate_weights_pyramid, make_grid_ds | ||
from ndpyramid.testing import verify_bounds | ||
|
||
|
||
@pytest.mark.parametrize('regridder_apply_kws', [None, {'keep_attrs': False}]) | ||
def test_regridded_pyramid(temperature, regridder_apply_kws, benchmark): | ||
pytest.importorskip('xesmf') | ||
pyramid = benchmark( | ||
lambda: pyramid_regrid( | ||
temperature, levels=2, regridder_apply_kws=regridder_apply_kws, other_chunks={'time': 2} | ||
) | ||
) | ||
verify_bounds(pyramid) | ||
assert pyramid.ds.attrs['multiscales'] | ||
assert pyramid.attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857' | ||
assert pyramid['0'].attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857' | ||
expected_attrs = ( | ||
temperature['air'].attrs | ||
if not regridder_apply_kws or regridder_apply_kws.get('keep_attrs') | ||
else {} | ||
) | ||
assert pyramid['0'].ds.air.attrs == expected_attrs | ||
assert pyramid['1'].ds.air.attrs == expected_attrs | ||
pyramid.to_zarr(MemoryStore()) | ||
|
||
|
||
def test_regridded_pyramid_with_weights(temperature, benchmark): | ||
pytest.importorskip('xesmf') | ||
levels = 2 | ||
weights_pyramid = generate_weights_pyramid(temperature.isel(time=0), levels) | ||
pyramid = benchmark( | ||
lambda: pyramid_regrid( | ||
temperature, levels=levels, weights_pyramid=weights_pyramid, other_chunks={'time': 2} | ||
) | ||
) | ||
verify_bounds(pyramid) | ||
assert pyramid.ds.attrs['multiscales'] | ||
assert len(pyramid.ds.attrs['multiscales'][0]['datasets']) == levels | ||
pyramid.to_zarr(MemoryStore()) | ||
|
||
|
||
@pytest.mark.parametrize('projection', ['web-mercator', 'equidistant-cylindrical']) | ||
def test_make_grid_ds(projection, benchmark): | ||
|
||
grid = benchmark(lambda: make_grid_ds(0, pixels_per_tile=8, projection=projection)) | ||
lon_vals = grid.lon_b.values | ||
assert np.all((lon_vals[-1, :] - lon_vals[0, :]) < 0.001) | ||
assert ( | ||
grid.attrs['title'] == 'Web Mercator Grid' | ||
if projection == 'web-mercator' | ||
else 'Equidistant Cylindrical Grid' | ||
) | ||
|
||
|
||
@pytest.mark.parametrize('levels', [1, 2]) | ||
@pytest.mark.parametrize('method', ['bilinear', 'conservative']) | ||
def test_generate_weights_pyramid(temperature, levels, method, benchmark): | ||
pytest.importorskip('xesmf') | ||
weights_pyramid = benchmark( | ||
lambda: generate_weights_pyramid(temperature.isel(time=0), levels, method=method) | ||
) | ||
assert weights_pyramid.ds.attrs['levels'] == levels | ||
assert weights_pyramid.ds.attrs['regrid_method'] == method | ||
assert set(weights_pyramid['0'].ds.data_vars) == {'S', 'col', 'row'} | ||
assert 'n_in' in weights_pyramid['0'].ds.attrs and 'n_out' in weights_pyramid['0'].ds.attrs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
import numpy as np | ||
import pytest | ||
import xarray as xr | ||
from zarr.storage import MemoryStore | ||
|
||
from ndpyramid import ( | ||
pyramid_reproject, | ||
pyramid_resample, | ||
) | ||
from ndpyramid.testing import verify_bounds | ||
|
||
|
||
@pytest.mark.parametrize('resampling', ['bilinear', 'nearest']) | ||
def test_resampled_pyramid(temperature, benchmark, resampling): | ||
pytest.importorskip('pyresample') | ||
pytest.importorskip('rioxarray') | ||
levels = 2 | ||
temperature = temperature.rio.write_crs('EPSG:4326') | ||
temperature = temperature.transpose('time', 'lat', 'lon') | ||
# import pdb; pdb.set_trace() | ||
|
||
pyramid = benchmark( | ||
lambda: pyramid_resample( | ||
temperature, levels=levels, x='lon', y='lat', resampling=resampling | ||
) | ||
) | ||
verify_bounds(pyramid) | ||
assert pyramid.ds.attrs['multiscales'] | ||
assert len(pyramid.ds.attrs['multiscales'][0]['datasets']) == levels | ||
assert pyramid.attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857' | ||
assert pyramid['0'].attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857' | ||
pyramid.to_zarr(MemoryStore()) | ||
|
||
|
||
@pytest.mark.parametrize('method', ['bilinear', 'nearest', {'air': 'nearest'}]) | ||
def test_resampled_pyramid_2D(temperature, method, benchmark): | ||
pytest.importorskip('pyresample') | ||
pytest.importorskip('rioxarray') | ||
levels = 2 | ||
temperature = temperature.rio.write_crs('EPSG:4326') | ||
temperature = temperature.isel(time=0).drop_vars('time') | ||
pyramid = benchmark( | ||
lambda: pyramid_resample(temperature, levels=levels, x='lon', y='lat', resampling=method) | ||
) | ||
verify_bounds(pyramid) | ||
assert pyramid.ds.attrs['multiscales'] | ||
assert len(pyramid.ds.attrs['multiscales'][0]['datasets']) == levels | ||
assert pyramid.attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857' | ||
assert pyramid['0'].attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857' | ||
pyramid.to_zarr(MemoryStore()) | ||
|
||
|
||
def test_reprojected_pyramid_clear_attrs(dataset_3d, benchmark): | ||
pytest.importorskip('rioxarray') | ||
levels = 2 | ||
pyramid = benchmark( | ||
lambda: pyramid_resample(dataset_3d, levels=levels, x='x', y='y', clear_attrs=True) | ||
) | ||
verify_bounds(pyramid) | ||
for _, da in pyramid['0'].ds.items(): | ||
assert not da.attrs | ||
pyramid.to_zarr(MemoryStore()) | ||
|
||
|
||
@pytest.mark.xfail(reseason='Need to fix handling of other_chunks') | ||
def test_reprojected_pyramid_other_chunks(dataset_3d, benchmark): | ||
pytest.importorskip('rioxarray') | ||
levels = 2 | ||
pyramid = benchmark( | ||
lambda: pyramid_resample(dataset_3d, levels=levels, x='x', y='y', other_chunks={'time': 5}) | ||
) | ||
verify_bounds(pyramid) | ||
pyramid.to_zarr(MemoryStore()) | ||
|
||
|
||
def test_resampled_pyramid_without_CF(dataset_3d, benchmark): | ||
pytest.importorskip('pyresample') | ||
pytest.importorskip('rioxarray') | ||
levels = 2 | ||
pyramid = benchmark(lambda: pyramid_resample(dataset_3d, levels=levels, x='x', y='y')) | ||
verify_bounds(pyramid) | ||
assert pyramid.ds.attrs['multiscales'] | ||
assert len(pyramid.ds.attrs['multiscales'][0]['datasets']) == levels | ||
assert pyramid.attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857' | ||
assert pyramid['0'].attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857' | ||
pyramid.to_zarr(MemoryStore()) | ||
|
||
|
||
def test_resampled_pyramid_fill(temperature, benchmark): | ||
""" | ||
Test for https://github.com/carbonplan/ndpyramid/issues/93. | ||
""" | ||
pytest.importorskip('pyresample') | ||
pytest.importorskip('rioxarray') | ||
temperature = temperature.rio.write_crs('EPSG:4326') | ||
pyramid = benchmark(lambda: pyramid_resample(temperature, levels=1, x='lon', y='lat')) | ||
assert np.isnan(pyramid['0'].air.isel(time=0, x=0, y=0).values) | ||
|
||
|
||
@pytest.mark.xfail(reseason='Differences between rasterio and pyresample to be investigated') | ||
def test_reprojected_resample_pyramid_values(temperature, benchmark): | ||
pytest.importorskip('rioxarray') | ||
levels = 2 | ||
temperature = temperature.rio.write_crs('EPSG:4326') | ||
temperature = temperature.chunk({'time': 10, 'lat': 10, 'lon': 10}) | ||
reprojected = benchmark( | ||
lambda: pyramid_reproject(temperature, levels=levels, resampling='nearest') | ||
) | ||
resampled = benchmark( | ||
lambda: pyramid_resample( | ||
temperature, levels=levels, x='lon', y='lat', resampling='nearest_neighbour' | ||
) | ||
) | ||
xr.testing.assert_allclose(reprojected['0'].ds, resampled['0'].ds) | ||
xr.testing.assert_allclose(reprojected['1'].ds, resampled['1'].ds) |
Oops, something went wrong.