Skip to content

Commit

Permalink
Support start_index=1. Fixes #119
Browse files Browse the repository at this point in the history
  • Loading branch information
Huite committed Sep 6, 2024
1 parent 178fd77 commit 9becbac
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 35 deletions.
18 changes: 18 additions & 0 deletions tests/test_ugrid2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,24 @@ def test_ugrid2d_dataset_no_mutation():
assert ds.identical(reference)


@pytest.mark.parametrize("edge_start_index", [0, 1])
@pytest.mark.parametrize("face_start_index", [0, 1])
def test_ugrid2d_from_dataset__different_start_index(
face_start_index, edge_start_index
):
grid = grid2d()
ds = grid.to_dataset(optional_attributes=True) # include edge_nodes
faces = ds["mesh2d_face_nodes"].to_numpy()
faces[faces != -1] += face_start_index
ds["mesh2d_face_nodes"].attrs["start_index"] = face_start_index
ds["mesh2d_edge_nodes"] += edge_start_index
ds["mesh2d_edge_nodes"].attrs["start_index"] = edge_start_index
new = xugrid.Ugrid2d.from_dataset(ds)
assert new.start_index == face_start_index
assert np.array_equal(new.face_node_connectivity, grid.face_node_connectivity)
assert np.array_equal(new.edge_node_connectivity, grid.edge_node_connectivity)


def test_ugrid2d_from_meshkernel():
# Setup a meshkernel Mesh2d mimick
class Mesh2d(NamedTuple):
Expand Down
16 changes: 13 additions & 3 deletions tests/test_ugrid_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1197,11 +1197,12 @@ def test_alternative_fill_value_start_index():
# Check internal fill value. Should be FILL_VALUE
grid = uds.ugrid.grid
assert grid.face_node_connectivity.dtype == "int64"
assert grid.start_index == 1
assert grid.fill_value == -999
assert (grid.face_node_connectivity != -999).all()
gridds = grid.to_dataset()
# Should be set back to the origina fill value.
faces = gridds["mesh2d_face_nodes"]
assert (faces != xugrid.constants.FILL_VALUE).all()
assert faces.attrs["start_index"] == 1
uniq = np.unique(faces)
assert uniq[0] == -999
Expand All @@ -1210,20 +1211,29 @@ def test_alternative_fill_value_start_index():
# And similarly for the UgridAccessors.
ds = uds.ugrid.to_dataset()
faces = ds["mesh2d_face_nodes"]
assert (faces != xugrid.constants.FILL_VALUE).all()
assert faces.attrs["start_index"] == 1
uniq = np.unique(faces)
assert uniq[0] == -999
assert uniq[1] == 1

ds_uda = uds["mesh2d_facevar"].ugrid.to_dataset()
faces = ds_uda["mesh2d_face_nodes"]
assert (faces != xugrid.constants.FILL_VALUE).all()
assert faces.attrs["start_index"] == 1
uniq = np.unique(faces)
assert uniq[0] == -999
assert uniq[1] == 1

# Alternative value
grid.start_index = 0
grid.fill_value = -2
gridds = grid.to_dataset()
# Should be set back to the origina fill value.
faces = gridds["mesh2d_face_nodes"]
assert faces.attrs["start_index"] == 0
uniq = np.unique(faces)
assert uniq[0] == -2
assert uniq[1] == 0


def test_fm_facenodeconnectivity_fillvalue():
"""
Expand Down
14 changes: 9 additions & 5 deletions xugrid/ugrid/ugrid1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,10 @@ def from_dataset(cls, dataset: xr.Dataset, topology: str = None):
node_y_coordinates = ds[y_index].astype(FloatDType).to_numpy()

edge_nodes = connectivity["edge_node_connectivity"]
fill_value = ds[edge_nodes].encoding.get("_FillValue", -1)
start_index = ds[edge_nodes].attrs.get("start_index", 0)
edge_node_connectivity = cls._prepare_connectivity(
ds[edge_nodes], dtype=IntDType
ds[edge_nodes], fill_value, dtype=IntDType
).to_numpy()

indexes["node_x"] = x_index
Expand All @@ -159,13 +161,14 @@ def from_dataset(cls, dataset: xr.Dataset, topology: str = None):
return cls(
node_x_coordinates,
node_y_coordinates,
FILL_VALUE,
fill_value,
edge_node_connectivity,
name=topology,
dataset=dataset[ugrid_vars],
indexes=indexes,
projected=projected,
crs=None,
start_index=start_index,
)

def _clear_geometry_properties(self):
Expand Down Expand Up @@ -246,7 +249,7 @@ def to_dataset(

dataset = xr.Dataset(data_vars, attrs=attrs)
if self._dataset:
dataset.update(self._dataset)
dataset = dataset.merge(self._dataset, compat="override")
if other is not None:
dataset = dataset.merge(other)
if node_x not in dataset or node_y not in dataset:
Expand Down Expand Up @@ -574,17 +577,18 @@ def topology_subset(
new_edges = connectivity.renumber(edge_subset)
node_x = self.node_x[node_index]
node_y = self.node_y[node_index]
grid = self.__class__(
grid = Ugrid1d(
node_x,
node_y,
self.fill_value,
FILL_VALUE,
new_edges,
name=self.name,
indexes=self._indexes,
projected=self.projected,
crs=self.crs,
attrs=self._attrs,
)
self._propagate_properties(grid)
if return_index:
indexes = {
self.node_dimension: pd.Index(node_index),
Expand Down
52 changes: 37 additions & 15 deletions xugrid/ugrid/ugrid2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,12 @@ def __init__(
)

# Ensure the fill value is FILL_VALUE (-1) and the array is 0-based.
is_fill = self.face_node_connectivity == self.fill_value
if self.start_index != 0:
self.face_node_connectivity[~is_fill] -= self.start_index
if self.fill_value != FILL_VALUE:
self.face_node_connectivity[is_fill] = FILL_VALUE
if self.fill_value != -1 or self.start_index != 0:
is_fill = self.face_node_connectivity == self.fill_value
if self.start_index != 0:
self.face_node_connectivity[~is_fill] -= self.start_index
if self.fill_value != FILL_VALUE:
self.face_node_connectivity[is_fill] = FILL_VALUE

# TODO: do this in validation instead. While UGRID conventions demand it,
# where does it go wrong?
Expand Down Expand Up @@ -156,7 +157,9 @@ def __init__(
self._edge_x = None
self._edge_y = None
# Connectivity
self.edge_node_connectivity = edge_node_connectivity
self._edge_node_connectivity = edge_node_connectivity
if self._edge_node_connectivity is not None:
self._edge_node_connectivity -= self.start_index
self._edge_face_connectivity = None
self._node_node_connectivity = None
self._node_edge_connectivity = None
Expand Down Expand Up @@ -290,15 +293,23 @@ def from_dataset(cls, dataset: xr.Dataset, topology: str = None):

face_nodes = connectivity["face_node_connectivity"]
fill_value = ds[face_nodes].encoding.get("_FillValue", -1)
start_index = ds[face_nodes].attrs.get("start_index", 0)
face_node_connectivity = cls._prepare_connectivity(
ds[face_nodes], dtype=IntDType
ds[face_nodes], fill_value, dtype=IntDType
).to_numpy()

edge_nodes = connectivity.get("edge_node_connectivity")
if edge_nodes:
edge_node_connectivity = cls._prepare_connectivity(
ds[edge_nodes], dtype=IntDType
ds[edge_nodes], fill_value, dtype=IntDType
).to_numpy()
# Make sure the single passed start index is valid for both
# connectivity arrays.
edge_start_index = ds[edge_nodes].attrs.get("start_index", 0)
if edge_start_index != start_index:
# start_index = 1, edge_start_index = 0, then add one
# start_index = 0, edge_start_index = 1, then subtract one
edge_node_connectivity += start_index - edge_start_index
else:
edge_node_connectivity = None

Expand All @@ -317,6 +328,7 @@ def from_dataset(cls, dataset: xr.Dataset, topology: str = None):
indexes=indexes,
projected=projected,
crs=None,
start_index=start_index,
)

def _get_name_and_attrs(self, name: str):
Expand Down Expand Up @@ -1137,10 +1149,10 @@ def topology_subset(
edge_subset = self.edge_node_connectivity[edge_index]
new_edges = connectivity.renumber(edge_subset)

grid = self.__class__(
grid = Ugrid2d(
node_x,
node_y,
self.fill_value,
FILL_VALUE,
new_faces,
name=self.name,
edge_node_connectivity=new_edges,
Expand All @@ -1149,6 +1161,7 @@ def topology_subset(
crs=self.crs,
attrs=self._attrs,
)
self._propagate_properties(grid)
if return_index:
indexes = {
self.node_dimension: pd.Index(node_index),
Expand Down Expand Up @@ -1631,6 +1644,8 @@ def merge_partitions(
crs=grid.crs,
attrs=grid._attrs,
)
# Maintain fill_value, start_index
grid._propagate_properties(merged_grid)
return merged_grid, indexes

def to_periodic(self, obj=None):
Expand Down Expand Up @@ -1683,14 +1698,15 @@ def to_periodic(self, obj=None):
node_x=new_xy[:, 0],
node_y=new_xy[:, 1],
face_node_connectivity=new_faces,
fill_value=self.fill_value,
fill_value=FILL_VALUE,
name=self.name,
edge_node_connectivity=new_edges,
indexes=self._indexes,
projected=self.projected,
crs=self.crs,
attrs=self.attrs,
)
self._propagate_properties(new)

if obj is not None:
indexes = {
Expand Down Expand Up @@ -1743,14 +1759,15 @@ def to_nonperiodic(self, xmax: float, obj=None):
node_x=np.concatenate((self.node_x, new_x)),
node_y=np.concatenate((self.node_y, new_y)),
face_node_connectivity=new_faces,
fill_value=self.fill_value,
fill_value=FILL_VALUE,
name=self.name,
edge_node_connectivity=None,
indexes=self._indexes,
projected=self.projected,
crs=self.crs,
attrs=self.attrs,
)
self._propagate_properties(new)

edge_index = None
if self._edge_node_connectivity is not None:
Expand Down Expand Up @@ -1863,7 +1880,9 @@ def triangulate(self):
triangles: Ugrid2d
"""
triangles, _ = connectivity.triangulate(self.face_node_connectivity)
return Ugrid2d(self.node_x, self.node_y, FILL_VALUE, triangles)
grid = Ugrid2d(self.node_x, self.node_y, FILL_VALUE, triangles)
self._propagate_properties(grid)
return grid

def _tesselate_voronoi(self, centroids, add_exterior, add_vertices, skip_concave):
if add_exterior:
Expand All @@ -1883,7 +1902,9 @@ def _tesselate_voronoi(self, centroids, add_exterior, add_vertices, skip_concave
add_vertices,
skip_concave,
)
return Ugrid2d(vertices[:, 0], vertices[:, 1], FILL_VALUE, faces)
grid = Ugrid2d(vertices[:, 0], vertices[:, 1], FILL_VALUE, faces)
self._propagate_properties(grid)
return grid

def tesselate_centroidal_voronoi(
self, add_exterior=True, add_vertices=True, skip_concave=False
Expand Down Expand Up @@ -1951,9 +1972,10 @@ def reverse_cuthill_mckee(self, dimension=None):
reordered_grid = Ugrid2d(
self.node_x,
self.node_y,
self.fill_value,
FILL_VALUE,
self.face_node_connectivity[reordering],
)
self._propagate_properties(reordered_grid)
return reordered_grid, reordering

def refine_polygon(
Expand Down
29 changes: 17 additions & 12 deletions xugrid/ugrid/ugridbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,10 @@ def rename(self, name: str, return_name_dict: bool = False):
else:
return new

def _propagate_properties(self, other) -> None:
other.start_index = self.start_index
other.fill_value = self.fill_value

@staticmethod
def _single_topology(dataset: xr.Dataset):
topologies = dataset.ugrid_roles.topology
Expand Down Expand Up @@ -475,11 +479,14 @@ def edge_bounds(self) -> FloatArray:
)

@staticmethod
def _prepare_connectivity(da: xr.DataArray, dtype: type) -> xr.DataArray:
start_index = da.attrs.get("start_index", 0)
if start_index not in (0, 1):
raise ValueError(f"start_index should be 0 or 1, received: {start_index}")

def _prepare_connectivity(
da: xr.DataArray, fill_value: Union[int, float], dtype: type
) -> xr.DataArray:
"""
Undo the work xarray does when it encounters a _FillValue for UGRID
connectivity arrays. Set an external unified value back (across all
connectivities!), and cast back to the desired dtype.
"""
data = da.to_numpy().copy()
# If xarray detects a _FillValue, it converts the array to floats and
# replaces the fill value by NaN, and moves the _FillValue to
Expand All @@ -489,12 +496,9 @@ def _prepare_connectivity(da: xr.DataArray, dtype: type) -> xr.DataArray:
else:
is_fill = np.isnan(data)
# Set the fill_value before casting: otherwise the cast may fail.
data[is_fill] = FILL_VALUE
data[is_fill] = fill_value
cast = data.astype(dtype, copy=False)

not_fill = ~is_fill
if start_index:
cast[not_fill] -= start_index
if (cast[not_fill] < 0).any():
raise ValueError("connectivity contains negative values")
return da.copy(data=cast)
Expand All @@ -504,10 +508,11 @@ def _adjust(self, connectivity: IntArray) -> IntArray:
c = connectivity.copy()
if self.start_index == 0 and self.fill_value == FILL_VALUE:
return c
if self.start_index != 0:
c = np.where(c != FILL_VALUE, c + self.start_index, c)
is_fill = c == FILL_VALUE
if self.start_index:
c[~is_fill] += self.start_index
if self.fill_value != FILL_VALUE:
c[c == FILL_VALUE] = self.fill_value
c[is_fill] = self.fill_value
return c

def _precheck(self, multi_index):
Expand Down

0 comments on commit 9becbac

Please sign in to comment.