Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
Merge branches 'docs/populate-docstrings' and 'docs/populate-docstrin…
Browse files Browse the repository at this point in the history
…gs' of https://github.com/ecmwf/anemoi-models into docs/populate-docstrings
  • Loading branch information
JesperDramsch committed Jun 5, 2024
2 parents 17c2edd + d72a86a commit 9484eef
Show file tree
Hide file tree
Showing 47 changed files with 2,869 additions and 174 deletions.
60 changes: 59 additions & 1 deletion docs/modules/distributed.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,68 @@
Distributed
#############

*******
graph
*******

.. automodule:: anemoi.models.distributed.graph
:members:
:no-undoc-members:
:show-inheritance:

************
khop_edges
************

.. automodule:: anemoi.models.distributed.khop_edges
:members:
:no-undoc-members:
:show-inheritance:

..
*************
..
primitives
..
*************
..
.. automodule:: anemoi.models.distributed.primitives
..
:members:
..
:no-undoc-members:
..
:show-inheritance:
********
shapes
********

.. automodule:: anemoi.models.distributed.shapes
:members:
:no-undoc-members:
:show-inheritance:

*************
transformer
*************

.. automodule:: anemoi.models.distributed.transformer
:members:
:no-undoc-members:
:show-inheritance:

*******
utils
*******

.. automodule:: anemoi.models.distributed
.. automodule:: anemoi.models.distributed.utils
:members:
:no-undoc-members:
:show-inheritance:
27 changes: 21 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#!/usr/bin/env python
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
Expand All @@ -10,12 +9,13 @@
# https://packaging.python.org/en/latest/guides/writing-pyproject-toml/

[build-system]
requires = ["setuptools>=60", "setuptools-scm>=8.0"]
requires = ["setuptools>=61", "setuptools-scm>=8.0"]
build-backend = "setuptools.build_meta"

[project]
description = "A package to hold various functions to support training of ML models."
name = "anemoi-models"

readme = "README.md"
dynamic = ["version"]
license = { file = "LICENSE" }
requires-python = ">=3.9"
Expand All @@ -39,19 +39,34 @@ classifiers = [
"Operating System :: OS Independent",
]

dependencies = []
dependencies = [
"torch==2.3",
"torch-geometric==2.4",
"einops==0.6.1",
"hydra-core==1.3",
"anemoi-datasets==0.2.1",
"anemoi-utils==0.1.9",
]

[project.optional-dependencies]


docs = [
# For building the documentation
"sphinx", "sphinx_rtd_theme", "nbsphinx", "pandoc", "sphinx_argparse"
]

all = []

dev = []
tests = ["pytest", "hypothesis"]

dev = [
"sphinx",
"sphinx_rtd_theme",
"nbsphinx",
"pandoc",
"pytest",
"hypothesis",
]

[project.urls]
Homepage = "https://github.com/ecmwf/anemoi-models/"
Expand Down
74 changes: 74 additions & 0 deletions src/anemoi/models/data_indices/collection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

import operator

import yaml
from omegaconf import OmegaConf

from anemoi.models.data_indices.index import BaseIndex
from anemoi.models.data_indices.index import DataIndex
from anemoi.models.data_indices.index import ModelIndex
from anemoi.models.data_indices.tensor import BaseTensorIndex
from anemoi.models.data_indices.tensor import InputTensorIndex
from anemoi.models.data_indices.tensor import OutputTensorIndex


class IndexCollection:
"""Collection of data and model indices."""

def __init__(self, config, name_to_index) -> None:
self.config = OmegaConf.to_container(config, resolve=True)

self.forcing = [] if config.data.forcing is None else OmegaConf.to_container(config.data.forcing, resolve=True)
self.diagnostic = (
[] if config.data.diagnostic is None else OmegaConf.to_container(config.data.diagnostic, resolve=True)
)

assert set(self.diagnostic).isdisjoint(self.forcing), (
f"Diagnostic and forcing variables overlap: {set(self.diagnostic).intersection(self.forcing)}. ",
"Please drop them at a dataset-level to exclude them from the training data.",
)
self.name_to_index = dict(sorted(name_to_index.items(), key=operator.itemgetter(1)))
name_to_index_model_input = {
name: i for i, name in enumerate(key for key in self.name_to_index if key not in self.diagnostic)
}
name_to_index_model_output = {
name: i for i, name in enumerate(key for key in self.name_to_index if key not in self.forcing)
}

self.data = DataIndex(self.diagnostic, self.forcing, self.name_to_index)
self.model = ModelIndex(self.diagnostic, self.forcing, name_to_index_model_input, name_to_index_model_output)

def __repr__(self) -> str:
return f"IndexCollection(config={self.config}, name_to_index={self.name_to_index})"

def __eq__(self, other):
if not isinstance(other, IndexCollection):
# don't attempt to compare against unrelated types
return NotImplemented

return self.model == other.model and self.data == other.data

def __getitem__(self, key):
return getattr(self, key)

def todict(self):
return {
"data": self.data.todict(),
"model": self.model.todict(),
}

@staticmethod
def representer(dumper, data):
return dumper.represent_scalar(f"!{data.__class__.__name__}", repr(data))


for cls in [BaseTensorIndex, InputTensorIndex, OutputTensorIndex, BaseIndex, DataIndex, ModelIndex, IndexCollection]:
yaml.add_representer(cls, cls.representer)
93 changes: 93 additions & 0 deletions src/anemoi/models/data_indices/index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

from anemoi.models.data_indices.tensor import InputTensorIndex
from anemoi.models.data_indices.tensor import OutputTensorIndex


class BaseIndex:
"""Base class for data and model indices."""

def __init__(self) -> None:
self.input = NotImplementedError
self.output = NotImplementedError

def __eq__(self, other):
if not isinstance(other, BaseIndex):
# don't attempt to compare against unrelated types
return NotImplemented

return self.input == other.input and self.output == other.output

def __repr__(self) -> str:
return f"{self.__class__.__name__}(input={self.input}, output={self.output})"

def __getitem__(self, key):
return getattr(self, key)

def todict(self):
return {
"input": self.input.todict(),
"output": self.output.todict(),
}

@staticmethod
def representer(dumper, data):
return dumper.represent_scalar(f"!{data.__class__.__name__}", repr(data))


class DataIndex(BaseIndex):
"""Indexing for data variables."""

def __init__(self, diagnostic, forcing, name_to_index) -> None:
self._diagnostic = diagnostic
self._forcing = forcing
self._name_to_index = name_to_index
self.input = InputTensorIndex(
includes=forcing,
excludes=diagnostic,
name_to_index=name_to_index,
)

self.output = OutputTensorIndex(
includes=diagnostic,
excludes=forcing,
name_to_index=name_to_index,
)

def __repr__(self) -> str:
return f"{self.__class__.__name__}(diagnostic={self._input}, forcing={self._output}, name_to_index={self._name_to_index})"


class ModelIndex(BaseIndex):
"""Indexing for model variables."""

def __init__(self, diagnostic, forcing, name_to_index_model_input, name_to_index_model_output) -> None:
self._diagnostic = diagnostic
self._forcing = forcing
self._name_to_index_model_input = name_to_index_model_input
self._name_to_index_model_output = name_to_index_model_output
self.input = InputTensorIndex(
includes=forcing,
excludes=[],
name_to_index=name_to_index_model_input,
)

self.output = OutputTensorIndex(
includes=diagnostic,
excludes=[],
name_to_index=name_to_index_model_output,
)

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(diagnostic={self._input}, forcing={self._output}, "
f"name_to_index_model_input={self._name_to_index_model_input}, "
f"name_to_index_model_output={self._name_to_index_model_output})"
)
Loading

0 comments on commit 9484eef

Please sign in to comment.