Skip to content

Commit

Permalink
Merge pull request #768 from pnuu/feature-bucket-resampling
Browse files Browse the repository at this point in the history
Expose bucket resampling from Pyresample
  • Loading branch information
mraspaud authored Sep 19, 2019
2 parents de39762 + b32428b commit 78d9aa7
Show file tree
Hide file tree
Showing 2 changed files with 418 additions and 0 deletions.
181 changes: 181 additions & 0 deletions satpy/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@
"ewa", "Elliptical Weighted Averaging", :class:`~satpy.resample.EWAResampler`
"native", "Native", :class:`~satpy.resample.NativeResampler`
"bilinear", "Bilinear", :class:`~satpy.resample.BilinearResampler`
"bucket_avg", "Average Bucket Resampling", :class:`~satpy.resample.BucketAvg`
"bucket_sum", "Sum Bucket Resampling", :class:`~satpy.resample.BucketSum`
"bucket_count", "Count Bucket Resampling", :class:`~satpy.resample.BucketCount`
"bucket_fraction", "Fraction Bucket Resampling", :class:`~satpy.resample.BucketFraction`
The resampling algorithm used can be specified with the ``resampler`` keyword
argument and defaults to ``nearest``:
Expand Down Expand Up @@ -141,6 +145,7 @@
from pyresample.geometry import SwathDefinition
from pyresample.kd_tree import XArrayResamplerNN
from pyresample.bilinear.xarr import XArrayResamplerBilinear
from pyresample import bucket
from satpy import CHUNK_SIZE
from satpy.config import config_search_paths, get_config_path

Expand Down Expand Up @@ -1026,11 +1031,187 @@ def compute(self, data, expand=True, **kwargs):
return update_resampled_coords(data, new_data, target_geo_def)


class BucketResamplerBase(BaseResampler):
"""Base class for bucket resampling which implements averaging.
"""

def __init__(self, source_geo_def, target_geo_def):
super(BucketResamplerBase, self).__init__(source_geo_def, target_geo_def)
self.resampler = None

def precompute(self, **kwargs):
"""Create X and Y indices and store them for later use."""
LOG.debug("Initializing bucket resampler.")
source_lons, source_lats = self.source_geo_def.get_lonlats(
chunks=CHUNK_SIZE)
self.resampler = bucket.BucketResampler(self.target_geo_def,
source_lons,
source_lats)

def compute(self, data, **kwargs):
"""Call the resampling."""
raise NotImplementedError("Use the sub-classes")

def resample(self, data, **kwargs):
"""Resample `data` by calling `precompute` and `compute` methods.
Args:
data (xarray.DataArray): Data to be resampled
Returns (xarray.DataArray): Data resampled to the target area
"""
self.precompute(**kwargs)
attrs = data.attrs.copy()
data_arr = data.data
if data.ndim == 3 and data.dims[0] == 'bands':
dims = ('bands', 'y', 'x')
# Both one and two dimensional input data results in 2D output
elif data.ndim in (1, 2):
dims = ('y', 'x')
else:
dims = data.dims
result = self.compute(data_arr, **kwargs)
coords = {}
if 'bands' in data.coords:
coords['bands'] = data.coords['bands']
# Fractions are returned in a dict
elif isinstance(result, dict):
coords['categories'] = sorted(result.keys())
dims = ('categories', 'y', 'x')
new_result = []
for cat in coords['categories']:
new_result.append(result[cat])
result = da.stack(new_result)
if result.ndim > len(dims):
result = da.squeeze(result)

# Adjust some attributes
if "BucketFraction" in str(self):
attrs['units'] = ''
attrs['calibration'] = ''
attrs['standard_name'] = 'area_fraction'
elif "BucketCount" in str(self):
attrs['units'] = ''
attrs['calibration'] = ''
attrs['standard_name'] = 'number_of_observations'

result = xr.DataArray(result, dims=dims, coords=coords,
attrs=attrs)

return result


class BucketAvg(BucketResamplerBase):
"""Class for averaging bucket resampling.
Bucket resampling calculates the average of all the values that
are closest to each bin and inside the target area.
Parameters
----------
fill_value : float (default: np.nan)
Fill value for missing data
mask_all_nans : boolean (default: False)
Mask all locations with all-NaN values
"""

def compute(self, data, fill_value=np.nan, mask_all_nan=False, **kwargs):
"""Call the resampling."""
results = []
if data.ndim == 3:
for i in range(data.shape[0]):
res = self.resampler.get_average(data[i, :, :],
fill_value=fill_value,
mask_all_nan=mask_all_nan)
results.append(res)
else:
res = self.resampler.get_average(data, fill_value=fill_value,
mask_all_nan=mask_all_nan)
results.append(res)

return da.stack(results)


class BucketSum(BucketResamplerBase):
"""Class for bucket resampling which implements accumulation (sum).
This resampler calculates the cumulative sum of all the values
that are closest to each bin and inside the target area.
Parameters
----------
fill_value : float (default: np.nan)
Fill value for missing data
mask_all_nans : boolean (default: False)
Mask all locations with all-NaN values
"""

def compute(self, data, mask_all_nan=False, **kwargs):
"""Call the resampling."""
LOG.debug("Resampling %s", str(data.name))
results = []
if data.ndim == 3:
for i in range(data.shape[0]):
res = self.resampler.get_sum(data[i, :, :],
mask_all_nan=mask_all_nan)
results.append(res)
else:
res = self.resampler.get_sum(data, mask_all_nan=mask_all_nan)
results.append(res)

return da.stack(results)


class BucketCount(BucketResamplerBase):
"""Class for bucket resampling which implements hit-counting.
This resampler calculates the number of occurences of the input
data closest to each bin and inside the target area.
"""

def compute(self, data, **kwargs):
"""Call the resampling."""
LOG.debug("Resampling %s", str(data.name))
results = []
if data.ndim == 3:
for i in range(data.shape[0]):
res = self.resampler.get_count()
results.append(res)
else:
res = self.resampler.get_count()
results.append(res)

return da.stack(results)


class BucketFraction(BucketResamplerBase):
"""Class for bucket resampling to compute category fractions
This resampler calculates the fraction of occurences of the input
data per category.
"""

def compute(self, data, fill_value=np.nan, categories=None, **kwargs):
"""Call the resampling."""
LOG.debug("Resampling %s", str(data.name))
if data.ndim > 2:
raise ValueError("BucketFraction not implemented for 3D datasets")

result = self.resampler.get_fractions(data, categories=categories,
fill_value=fill_value)

return result


RESAMPLERS = {"kd_tree": KDTreeResampler,
"nearest": KDTreeResampler,
"ewa": EWAResampler,
"bilinear": BilinearResampler,
"native": NativeResampler,
"bucket_avg": BucketAvg,
"bucket_sum": BucketSum,
"bucket_count": BucketCount,
"bucket_fraction": BucketFraction,
}


Expand Down
Loading

0 comments on commit 78d9aa7

Please sign in to comment.