Skip to content

Commit

Permalink
single level func prototype -- non refactored
Browse files Browse the repository at this point in the history
  • Loading branch information
norlandrhagen committed Dec 12, 2023
1 parent 55fe090 commit e24aeb4
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 2 deletions.
2 changes: 1 addition & 1 deletion ndpyramid/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# flake8: noqa

from .core import pyramid_coarsen, pyramid_reproject
from .core import pyramid_coarsen, pyramid_reproject, reproject_single_level
from .regrid import pyramid_regrid
from ._version import __version__
107 changes: 106 additions & 1 deletion ndpyramid/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,14 @@
import datatree as dt
import xarray as xr

from ._version import __version__
from .common import Projection
from .utils import add_metadata_and_zarr_encoding, get_version, multiscales_template
from .utils import (
add_metadata_and_zarr_encoding,
get_version,
multiscales_template,
set_zarr_encoding,
)


def pyramid_coarsen(
Expand Down Expand Up @@ -53,6 +59,105 @@ def pyramid_coarsen(
plevels['/'] = xr.Dataset(attrs=attrs)
return dt.DataTree.from_dict(plevels)



# single_level branch
def reproject_single_level(
ds: xr.Dataset,
*,
projection:typing.Literal['web-mercator', 'equidistant-cylindrical'] = 'web-mercator',
level: int = None,
pixels_per_tile: int = 128,
other_chunks: dict = None,
resampling: str | dict = 'average',
extra_dim: str = None,
) -> dt.DataTree:


import rioxarray # noqa: F401
from rasterio.warp import Resampling

# multiscales spec
save_kwargs = {'levels': level, 'pixels_per_tile': pixels_per_tile}
attrs = {
'multiscales': multiscales_template(
datasets=[{'path': str(i)} for i in [level]],
type='reduce',
method='pyramid_reproject',
version=get_version(),
kwargs=save_kwargs,
)
}

# Convert resampling from string to dictionary if necessary
if isinstance(resampling, str):
resampling_dict = defaultdict(lambda: resampling)
else:
resampling_dict = resampling

projection_model = Projection(name=projection)

# set up pyramid
plevels = {}

# pyramid data

lkey = str(level)
dim = 2**level * pixels_per_tile
dst_transform = projection_model.transform(dim=dim)

def reproject(da, var):
return da.rio.reproject(
projection_model._crs,
resampling=Resampling[resampling_dict[var]],
shape=(dim, dim),
transform=dst_transform,
)
# create the data array for each level

plevels[lkey] = xr.Dataset(attrs=ds.attrs)
for k, da in ds.items():
if len(da.shape) == 4:
# if extra_dim is not specified, raise an error
if extra_dim is None:
raise ValueError("must specify 'extra_dim' to iterate over 4d data")
da_all = []
for index in ds[extra_dim]:
# reproject each index of the 4th dimension
da_reprojected = reproject(da.sel({extra_dim: index}), k)
da_all.append(da_reprojected)
plevels[lkey][k] = xr.concat(da_all, ds[extra_dim])
else:
# if the data array is not 4D, just reproject it

plevels[lkey][k] = reproject(da, k)
level_ds = plevels[lkey]
level_ds.attrs = attrs


chunks = {'x': pixels_per_tile, 'y': pixels_per_tile}

if other_chunks is not None:
chunks |= other_chunks
level_ds.attrs['multiscales'][0]['metadata']['kwargs']['pixels_per_tile'] = pixels_per_tile
if projection:
level_ds.attrs['multiscales'][0]['datasets'][0]['crs'] = projection_model._crs
# set dataset chunks
level_ds = level_ds.chunk(chunks)

# set dataset encoding

level_ds = set_zarr_encoding(
level_ds, codec_config={'id': 'zlib', 'level': 1}, float_dtype='float32'
)

# set global metadata
level_ds.attrs.update({'title': 'multiscale data pyramid', 'version': __version__})
return level_ds




# single_level branch
def pyramid_reproject(
ds: xr.Dataset,
Expand Down

0 comments on commit e24aeb4

Please sign in to comment.