Skip to content

Commit

Permalink
HEALPix nodes (#16)
Browse files Browse the repository at this point in the history
[feature] HEALPix node builder

Co-authored-by: theissenhelen <[email protected]>
  • Loading branch information
JPXKQX and theissenhelen authored Aug 5, 2024
1 parent eabb8f1 commit bfdac7d
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 1 deletion.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Keep it human-readable, your future self will thank you!
## [Unreleased]

### Added

- HEALPixNodes - nodebuilder based on Hierarchical Equal Area isoLatitude Pixelation of a sphere
### Changed

### Removed
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ dependencies = [
"anemoi-datasets[data]>=0.3.3",
"anemoi-utils>=0.3.6",
"h3>=3.7.6,<4",
"healpy>=1.17",
"hydra-core>=1.3",
"networkx>=3.1",
"torch>=2.2",
Expand Down
51 changes: 51 additions & 0 deletions src/anemoi/graphs/nodes/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,54 @@ class HexNodes(IcosahedralNodes):

def create_nodes(self) -> np.ndarray:
return create_hexagonal_nodes(self.resolutions)


class HEALPixNodes(BaseNodeBuilder):
"""Nodes from HEALPix grid.
HEALPix is an acronym for Hierarchical Equal Area isoLatitude Pixelization of a sphere.
Attributes
----------
resolution : int
The resolution of the grid.
name : str
The name of the nodes.
Methods
-------
get_coordinates()
Get the lat-lon coordinates of the nodes.
register_nodes(graph, name)
Register the nodes in the graph.
register_attributes(graph, name, config)
Register the attributes in the nodes of the graph specified.
update_graph(graph, name, attr_config)
Update the graph with new nodes and attributes.
"""

def __init__(self, resolution: int, name: str) -> None:
"""Initialize the HEALPixNodes builder."""
self.resolution = resolution
super().__init__(name)

assert isinstance(resolution, int), "Resolution must be an integer."
assert resolution > 0, "Resolution must be positive."

def get_coordinates(self) -> torch.Tensor:
"""Get the coordinates of the nodes.
Returns
-------
torch.Tensor of shape (N, 2)
Coordinates of the nodes.
"""
import healpy as hp

spatial_res_degrees = hp.nside2resol(2**self.resolution, arcmin=True) / 60
LOGGER.info(f"Creating HEALPix nodes with resolution {spatial_res_degrees:.2} deg.")

npix = hp.nside2npix(2**self.resolution)
hpxlon, hpxlat = hp.pix2ang(2**self.resolution, range(npix), nest=True, lonlat=True)

return self.reshape_coords(hpxlat, hpxlon)
51 changes: 51 additions & 0 deletions tests/nodes/test_healpix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import pytest
import torch
from torch_geometric.data import HeteroData

from anemoi.graphs.nodes.attributes import AreaWeights
from anemoi.graphs.nodes.attributes import UniformWeights
from anemoi.graphs.nodes.builder import BaseNodeBuilder
from anemoi.graphs.nodes.builder import HEALPixNodes


@pytest.mark.parametrize("resolution", [2, 5, 7])
def test_init(resolution: int):
"""Test HEALPixNodes initialization."""
node_builder = HEALPixNodes(resolution, "test_nodes")
assert isinstance(node_builder, BaseNodeBuilder)
assert isinstance(node_builder, HEALPixNodes)


@pytest.mark.parametrize("resolution", ["2", 4.3, -7])
def test_fail_init(resolution: int):
"""Test HEALPixNodes initialization with invalid resolution."""
with pytest.raises(AssertionError):
HEALPixNodes(resolution, "test_nodes")


@pytest.mark.parametrize("resolution", [2, 5, 7])
def test_register_nodes(resolution: int):
"""Test HEALPixNodes register correctly the nodes."""
node_builder = HEALPixNodes(resolution, "test_nodes")
graph = HeteroData()

graph = node_builder.register_nodes(graph)

assert graph["test_nodes"].x is not None
assert isinstance(graph["test_nodes"].x, torch.Tensor)
assert graph["test_nodes"].x.shape[1] == 2
assert graph["test_nodes"].node_type == "HEALPixNodes"


@pytest.mark.parametrize("attr_class", [UniformWeights, AreaWeights])
@pytest.mark.parametrize("resolution", [2, 5, 7])
def test_register_attributes(graph_with_nodes: HeteroData, attr_class, resolution: int):
"""Test HEALPixNodes register correctly the weights."""
node_builder = HEALPixNodes(resolution, "test_nodes")
config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.attributes.{attr_class.__name__}"}}

graph = node_builder.register_attributes(graph_with_nodes, config)

assert graph["test_nodes"]["test_attr"] is not None
assert isinstance(graph["test_nodes"]["test_attr"], torch.Tensor)
assert graph["test_nodes"]["test_attr"].shape[0] == graph["test_nodes"].x.shape[0]

0 comments on commit bfdac7d

Please sign in to comment.