From b55adafb15cebccbedb33346c67df8250f00fbfa Mon Sep 17 00:00:00 2001 From: Huite Bootsma Date: Fri, 23 Aug 2024 17:48:53 +0200 Subject: [PATCH] Enable broadcasting/lazy evaluation in interpolate_na and laplace_interpolate. Fixes #292 Also changes dims into a set, consistent with xarray (future) behavior. --- docs/changelog.rst | 4 +- tests/test_interpolate.py | 14 +++---- tests/test_partitioning.py | 2 +- tests/test_ugrid1d.py | 2 +- tests/test_ugrid2d.py | 20 +++++++++- tests/test_ugrid_dataset.py | 22 +++++++++++ xugrid/core/dataarray_accessor.py | 64 ++++++++++++++++++++----------- xugrid/plot/plot.py | 2 +- xugrid/ugrid/interpolate.py | 30 +++++++++++++-- xugrid/ugrid/partitioning.py | 2 +- xugrid/ugrid/ugrid1d.py | 9 +++-- xugrid/ugrid/ugrid2d.py | 12 +++--- xugrid/ugrid/ugridbase.py | 16 ++++++-- 13 files changed, 145 insertions(+), 54 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index a59da3c34..6acdf3766 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -24,7 +24,9 @@ Added - :meth:`xugrid.UgridDataArrayAccessor.interpolate_na` has been added to fill missing data. Currently, the only supported method is ``"nearest"``. - :attr:`xugrid.Ugrid1.dims` and :attr:`xugrid.Ugrid2.dims` have been added to - return a tuple of the UGRID dimensions. + return a set of the UGRID dimensions. +- :meth:`xugrid.UgridDataArrayAccessor.laplace_interpolate` now broadcasts + over non-UGRID dimensions and supports lazy evaluation. 
Changed ~~~~~~~ diff --git a/tests/test_interpolate.py b/tests/test_interpolate.py index cf46dd74a..95d5dcf3b 100644 --- a/tests/test_interpolate.py +++ b/tests/test_interpolate.py @@ -37,17 +37,17 @@ def test_laplace_interpolate(): data = np.array([1.0, np.nan, np.nan, np.nan, 5.0]) with pytest.raises(ValueError, match="connectivity is not a square matrix"): con = sparse.coo_matrix(coo_content, shape=(4, 5)).tocsr() - interpolate.laplace_interpolate(con, data, use_weights=False) + interpolate.laplace_interpolate(data, con, use_weights=False) expected = np.arange(1.0, 6.0) con = sparse.coo_matrix(coo_content, shape=(5, 5)).tocsr() actual = interpolate.laplace_interpolate( - con, data, use_weights=False, direct_solve=True + data, con, use_weights=False, direct_solve=True ) assert np.allclose(actual, expected) actual = interpolate.laplace_interpolate( - con, data, use_weights=False, direct_solve=False + data, con, use_weights=False, direct_solve=False ) assert np.allclose(actual, expected) @@ -57,13 +57,11 @@ def test_nearest_interpolate(): y = np.zeros_like(x) coordinates = np.column_stack((x, y)) data = np.array([0.0, np.nan, np.nan, np.nan, 4.0]) - actual = interpolate.nearest_interpolate(coordinates, data, np.inf) + actual = interpolate.nearest_interpolate(data, coordinates, np.inf) assert np.allclose(actual, np.array([0.0, 0.0, 0.0, 4.0, 4.0])) - actual = interpolate.nearest_interpolate(coordinates, data, 1.1) + actual = interpolate.nearest_interpolate(data, coordinates, 1.1) assert np.allclose(actual, np.array([0.0, 0.0, np.nan, 4.0, 4.0]), equal_nan=True) with pytest.raises(ValueError, match="All values are NA."): - interpolate.nearest_interpolate( - coordinates, data=np.full_like(data, np.nan), max_distance=np.inf - ) + interpolate.nearest_interpolate(np.full_like(data, np.nan), coordinates, np.inf) diff --git a/tests/test_partitioning.py b/tests/test_partitioning.py index d1e93639d..632d684de 100644 --- a/tests/test_partitioning.py +++ 
b/tests/test_partitioning.py @@ -41,7 +41,7 @@ def test_labels_to_indices(): def test_single_ugrid_chunk(): grid = generate_mesh_2d(3, 3) - ugrid_dims = set(grid.dims) + ugrid_dims = grid.dims da = xr.DataArray(np.ones(grid.n_face), dims=(grid.face_dimension,)) assert pt.single_ugrid_chunk(da, ugrid_dims) is da diff --git a/tests/test_ugrid1d.py b/tests/test_ugrid1d.py index 362523737..56fe04ba7 100644 --- a/tests/test_ugrid1d.py +++ b/tests/test_ugrid1d.py @@ -273,7 +273,7 @@ def test_dimensions(): grid = grid1d() assert grid.node_dimension == f"{NAME}_nNodes" assert grid.edge_dimension == f"{NAME}_nEdges" - assert grid.dims == (f"{NAME}_nNodes", f"{NAME}_nEdges") + assert grid.dims == {f"{NAME}_nNodes", f"{NAME}_nEdges"} assert grid.sizes == { f"{NAME}_nNodes": 3, f"{NAME}_nEdges": 2, diff --git a/tests/test_ugrid2d.py b/tests/test_ugrid2d.py index 16a1a31be..ee787f93a 100644 --- a/tests/test_ugrid2d.py +++ b/tests/test_ugrid2d.py @@ -264,6 +264,22 @@ def check_attrs(ds): check_attrs(ds) +def test_find_ugrid_dim(): + grid = grid2d() + da = xr.DataArray(data=np.ones((grid.n_face,)), dims=[grid.face_dimension]) + assert grid.find_ugrid_dim(da) == grid.face_dimension + + weird = xr.DataArray( + data=np.ones((grid.n_face, grid.n_node)), + dims=[grid.face_dimension, grid.node_dimension], + ) + with pytest.raises( + ValueError, + match="UgridDataArray should contain exactly one of the UGRID dimension", + ): + grid.find_ugrid_dim(weird) + + def test_ugrid2d_set_node_coords(): grid = grid2d() ds = xr.Dataset() @@ -449,11 +465,11 @@ def test_dimensions(): assert grid.node_dimension == f"{NAME}_nNodes" assert grid.edge_dimension == f"{NAME}_nEdges" assert grid.face_dimension == f"{NAME}_nFaces" - assert grid.dims == ( + assert grid.dims == { f"{NAME}_nNodes", f"{NAME}_nEdges", f"{NAME}_nFaces", - ) + } assert grid.sizes == { f"{NAME}_nNodes": 7, f"{NAME}_nEdges": 10, diff --git a/tests/test_ugrid_dataset.py b/tests/test_ugrid_dataset.py index 865c194c1..39add6b7b 100644 
--- a/tests/test_ugrid_dataset.py +++ b/tests/test_ugrid_dataset.py @@ -1,5 +1,6 @@ import warnings +import dask import geopandas as gpd import numpy as np import pandas as pd @@ -387,6 +388,27 @@ def test_laplace_interpolate(self): assert isinstance(actual, xugrid.UgridDataArray) assert np.allclose(actual, 1.0) + def test_broadcasted_laplace_interpolate(self): + uda2 = self.uda.copy() + uda2.obj[:-2] = np.nan + multiplier = xr.DataArray( + np.ones((3, 2)), + coords={"time": [0, 1, 2], "layer": [1, 2]}, + dims=("time", "layer"), + ) + nd_uda2 = uda2 * multiplier + actual = nd_uda2.ugrid.laplace_interpolate(direct_solve=True) + assert isinstance(actual, xugrid.UgridDataArray) + assert np.allclose(actual, 1.0) + assert set(actual.dims) == set(nd_uda2.dims) + + # Test delayed evaluation too. + nd_uda2 = uda2 * multiplier.chunk({"time": 1}) + actual = nd_uda2.ugrid.laplace_interpolate(direct_solve=True) + assert isinstance(actual, xugrid.UgridDataArray) + assert set(actual.dims) == set(nd_uda2.dims) + assert isinstance(actual.data, dask.array.Array) + def test_to_dataset(self): uda2 = self.uda.copy() uda2.ugrid.obj.name = "test" diff --git a/xugrid/core/dataarray_accessor.py b/xugrid/core/dataarray_accessor.py index 5e32fb9da..847c3f09b 100644 --- a/xugrid/core/dataarray_accessor.py +++ b/xugrid/core/dataarray_accessor.py @@ -10,7 +10,11 @@ from xugrid.core.wrap import UgridDataArray, UgridDataset from xugrid.plot.plot import _PlotMethods from xugrid.ugrid import connectivity -from xugrid.ugrid.interpolate import laplace_interpolate, nearest_interpolate +from xugrid.ugrid.interpolate import ( + interpolate_na_helper, + laplace_interpolate, + nearest_interpolate, +) from xugrid.ugrid.ugrid1d import Ugrid1d from xugrid.ugrid.ugrid2d import Ugrid2d from xugrid.ugrid.ugridbase import UgridType @@ -587,6 +591,9 @@ def interpolate_na( """ Fill in NaNs by interpolating. + This function automatically finds the UGRID dimension and broadcasts + over the other dimensions. 
+ Parameters ---------- method: str, default is "nearest" @@ -607,13 +614,17 @@ def interpolate_na( grid = self.grid da = self.obj - - filled = nearest_interpolate( - coordinates=grid.get_coordinates(dim=da.dims[0]), - data=da.to_numpy(), - max_distance=max_distance, + ugrid_dim = grid.find_ugrid_dim(da) + + da_filled = interpolate_na_helper( + da, + ugrid_dim=ugrid_dim, + func=nearest_interpolate, + kwargs={ + "coordinates": grid.get_coordinates(ugrid_dim), + "max_distance": max_distance, + }, ) - da_filled = da.copy(data=filled) return UgridDataArray(da_filled, grid) def laplace_interpolate( @@ -629,6 +640,9 @@ def laplace_interpolate( """ Fill in NaNs by using Laplace interpolation. + This function automatically finds the UGRID dimension and broadcasts + over the other dimensions. + This solves Laplace's equation where where there is no data, with data values functioning as fixed potential boundary conditions. @@ -669,25 +683,29 @@ def laplace_interpolate( """ grid = self.grid da = self.obj - if len(da.dims) > 1: - # TODO: apply ufunc - raise NotImplementedError - if da.dims[0] == grid.edge_dimension: + + grid = self.grid + da = self.obj + ugrid_dim = grid.find_ugrid_dim(da) + if ugrid_dim == grid.edge_dimension: raise ValueError("Laplace interpolation along edges is not allowed.") - connectivity = grid.get_connectivity_matrix(da.dims[0], xy_weights=xy_weights) - filled = laplace_interpolate( - connectivity=connectivity, - data=da.to_numpy(), - use_weights=xy_weights, - direct_solve=direct_solve, - delta=delta, - relax=relax, - rtol=rtol, - atol=atol, - maxiter=maxiter, + connectivity = grid.get_connectivity_matrix(ugrid_dim, xy_weights=xy_weights) + da_filled = interpolate_na_helper( + da, + ugrid_dim, + func=laplace_interpolate, + kwargs={ + "connectivity": connectivity, + "use_weights": xy_weights, + "direct_solve": direct_solve, + "delta": delta, + "relax": relax, + "rtol": rtol, + "atol": atol, + "maxiter": maxiter, + }, ) - da_filled = da.copy(data=filled) 
return UgridDataArray(da_filled, grid) def to_dataset(self, optional_attributes: bool = False): diff --git a/xugrid/plot/plot.py b/xugrid/plot/plot.py index cc4e38253..cc6a425ba 100644 --- a/xugrid/plot/plot.py +++ b/xugrid/plot/plot.py @@ -625,7 +625,7 @@ def __init__(self, obj): darray = obj.obj grid = obj.grid - invalid = set(darray.dims) - set(grid.dims) + invalid = set(darray.dims) - grid.dims if invalid: raise ValueError( f"UgridDataArray contains non-topology dimensions: {invalid}.\n" diff --git a/xugrid/ugrid/interpolate.py b/xugrid/ugrid/interpolate.py index b660b7848..736245d42 100644 --- a/xugrid/ugrid/interpolate.py +++ b/xugrid/ugrid/interpolate.py @@ -1,10 +1,11 @@ from __future__ import annotations import warnings -from typing import NamedTuple, Tuple +from typing import Any, Callable, Dict, NamedTuple, Tuple import numba as nb import numpy as np +import xarray as xr from scipy import sparse from scipy.spatial import KDTree @@ -197,8 +198,8 @@ def __repr__(self) -> str: def laplace_interpolate( - connectivity: sparse.csr_matrix, data: FloatArray, + connectivity: sparse.csr_matrix, use_weights: bool, direct_solve: bool = False, delta=0.0, @@ -218,9 +219,9 @@ def laplace_interpolate( Parameters ---------- + data: ndarray of floats with shape ``(n,)`` connectivity: scipy.sparse.csr_matrix with shape ``(n, n)`` Sparse connectivity matrix containing ``n_nonzero`` indices and weight values. - data: ndarray of floats with shape ``(n,)`` use_weights: bool, default False. Wether to use the data attribute of the connectivity matrix as coefficients. If ``False``, defaults to uniform coefficients of 1. 
@@ -310,8 +311,8 @@ def laplace_interpolate( def nearest_interpolate( - coordinates: FloatArray, data: FloatArray, + coordinates: FloatArray, max_distance: float, ) -> FloatArray: isnull = np.isnan(data) @@ -337,3 +338,24 @@ def nearest_interpolate( out = data.copy() out[i_target] = data[i_source[index]] return out + + +def interpolate_na_helper( + da: xr.DataArray, + ugrid_dim: str, + func: Callable, + kwargs: Dict[str, Any], +): + """Use apply ufunc to broadcast over the non UGRID dims.""" + da_filled = xr.apply_ufunc( + func, + da, + input_core_dims=[[ugrid_dim]], + output_core_dims=[[ugrid_dim]], + vectorize=True, + kwargs=kwargs, + dask="parallelized", + keep_attrs=True, + output_dtypes=[da.dtype], + ) + return da_filled diff --git a/xugrid/ugrid/partitioning.py b/xugrid/ugrid/partitioning.py index 5d09fade4..e2a64e038 100644 --- a/xugrid/ugrid/partitioning.py +++ b/xugrid/ugrid/partitioning.py @@ -156,7 +156,7 @@ def validate_partition_topology(grouped: defaultdict[str, UgridType]) -> None: f"same type, received: {types}" ) - griddims = list({grid.dims for grid in grids}) + griddims = list({tuple(grid.dims) for grid in grids}) if len(griddims) > 1: raise ValueError( f"Dimension names on UGRID topology {name} do not match " diff --git a/xugrid/ugrid/ugrid1d.py b/xugrid/ugrid/ugrid1d.py index 61bd9b94b..16ab0dcd6 100644 --- a/xugrid/ugrid/ugrid1d.py +++ b/xugrid/ugrid/ugrid1d.py @@ -263,9 +263,8 @@ def core_dimension(self): @property def dims(self): - """Tuple of UGRID dimension names: node dimension, edge dimension.""" - # Tuple to preserve order, unlike set. 
- return (self.node_dimension, self.edge_dimension) + """Set of UGRID dimension names: node dimension, edge dimension.""" + return {self.node_dimension, self.edge_dimension} @property def sizes(self): @@ -466,7 +465,9 @@ def isel(self, indexers=None, return_index=False, **indexers_kwargs): ) indexers = {k: as_pandas_index(v, self.sizes[k]) for k, v in indexers.items()} - nodedim, edgedim = self.dims + nodedim = self.node_dimension + edgedim = self.edge_dimension + edge_index = {} if nodedim in indexers: node_index = indexers[nodedim] diff --git a/xugrid/ugrid/ugrid2d.py b/xugrid/ugrid/ugrid2d.py index 5abda222d..1ea8c936c 100644 --- a/xugrid/ugrid/ugrid2d.py +++ b/xugrid/ugrid/ugrid2d.py @@ -415,13 +415,12 @@ def core_dimension(self): @property def dims(self): - """Tuple of UGRID dimension names: node dimension, edge dimension, face_dimension.""" - # Tuple to preserve order, unlike set. - return ( + """Set of UGRID dimension names: node dimension, edge dimension, face_dimension.""" + return { self.node_dimension, self.edge_dimension, self.face_dimension, - ) + } @property def sizes(self): @@ -1202,7 +1201,10 @@ def isel(self, indexers=None, return_index=False, **indexers_kwargs): ) indexers = {k: as_pandas_index(v, self.sizes[k]) for k, v in indexers.items()} - nodedim, edgedim, facedim = self.dims + nodedim = self.node_dimension + edgedim = self.edge_dimension + facedim = self.face_dimension + face_index = {} if nodedim in indexers: node_index = indexers[nodedim] diff --git a/xugrid/ugrid/ugridbase.py b/xugrid/ugrid/ugridbase.py index 21a1a11e6..8a7295e6a 100644 --- a/xugrid/ugrid/ugridbase.py +++ b/xugrid/ugrid/ugridbase.py @@ -2,7 +2,7 @@ import copy import warnings from itertools import chain -from typing import Dict, Tuple, Type, Union, cast +from typing import Dict, Set, Tuple, Type, Union, cast import numpy as np import pandas as pd @@ -75,7 +75,7 @@ def align(obj, grids, old_indexes): # Group the indexers by grid new_grids = [] for grid in grids: - 
ugrid_dims = set(grid.dims).intersection(new_indexes) + ugrid_dims = grid.dims.intersection(new_indexes) ugrid_indexes = {dim: new_indexes[dim] for dim in ugrid_dims} newgrid, indexers = grid.isel(indexers=ugrid_indexes, return_index=True) indexers = { @@ -99,7 +99,7 @@ def core_dimension(self): @property @abc.abstractmethod - def dims(self) -> Tuple[str]: + def dims(self) -> Set[str]: pass @property @@ -502,6 +502,16 @@ def _postcheck(self, indexers, finalized_indexers): ) return + def find_ugrid_dim(self, obj: Union[xr.DataArray, xr.Dataset]): + """Find the UGRID dimension that is present in the object.""" + ugrid_dims = self.dims.intersection(obj.dims) + if len(ugrid_dims) != 1: + raise ValueError( + "UgridDataArray should contain exactly one of the UGRID " + f"dimensions: {self.dims}" + ) + return ugrid_dims.pop() + def set_node_coords( self, node_x: str,