diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..a19ade0 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +CHANGELOG.md merge=union diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000..a2c619f --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,6 @@ +# CODEOWNERS file + +# Protect workflow files +/.github/ @theissenhelen @jesperdramsch @gmertes +/.pre-commit-config.yaml @theissenhelen @jesperdramsch @gmertes +/pyproject.toml @theissenhelen @jesperdramsch @gmertes diff --git a/.github/ci-hpc-config.yml b/.github/ci-hpc-config.yml new file mode 100644 index 0000000..5c27b03 --- /dev/null +++ b/.github/ci-hpc-config.yml @@ -0,0 +1,7 @@ +build: + python: '3.10' + modules: + - ninja + python_dependencies: + - ecmwf/anemoi-utils@develop + parallel: 64 diff --git a/.github/workflows/changelog-pr-update.yml b/.github/workflows/changelog-pr-update.yml index 43acb1c..e7ed9a2 100644 --- a/.github/workflows/changelog-pr-update.yml +++ b/.github/workflows/changelog-pr-update.yml @@ -2,6 +2,9 @@ name: Check Changelog Update on PR on: pull_request: types: [assigned, opened, synchronize, reopened, labeled, unlabeled] + paths-ignore: + - .pre-commit-config.yaml + - .readthedocs.yaml jobs: Check-Changelog: name: Check Changelog Action diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5867ee0..8b2926b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -37,7 +37,7 @@ jobs: downstream-ci-hpc: name: downstream-ci-hpc if: ${{ !github.event.pull_request.head.repo.fork && github.event.action != 'labeled' || github.event.label.name == 'approved-for-ci' }} - uses: ecmwf-actions/downstream-ci/.github/workflows/downstream-ci.yml@main + uses: ecmwf-actions/downstream-ci/.github/workflows/downstream-ci-hpc.yml@main with: anemoi-models: ecmwf/anemoi-models@${{ github.event.pull_request.head.sha || github.sha }} secrets: inherit diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c042b1f..f3c3962 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,6 +20,12 @@ repos: - id: no-commit-to-branch # Prevent committing to main / master - id: check-added-large-files # Check for large files added to git - id: check-merge-conflict # Check for files that contain merge conflict +- repo: https://github.com/pre-commit/pygrep-hooks + rev: v1.10.0 # Use the ref you want to point at + hooks: + - id: python-use-type-annotations # Check for missing type annotations + - id: python-check-blanket-noqa # Check for # noqa: all + - id: python-no-log-warn # Check for log.warn - repo: https://github.com/psf/black-pre-commit-mirror rev: 24.8.0 hooks: @@ -34,7 +40,7 @@ repos: - --force-single-line-imports - --profile black - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.3 + rev: v0.6.4 hooks: - id: ruff # Next line if for documenation cod snippets @@ -45,7 +51,7 @@ repos: - --exit-non-zero-on-fix - --preview - repo: https://github.com/sphinx-contrib/sphinx-lint - rev: v0.9.1 + rev: v1.0.0 hooks: - id: sphinx-lint # For now, we use it. 
But it does not support a lot of sphinx features @@ -59,12 +65,21 @@ repos: hooks: - id: docconvert args: ["numpy"] -- repo: https://github.com/b8raoult/optional-dependencies-all - rev: "0.0.6" - hooks: - - id: optional-dependencies-all - args: ["--inplace", "--exclude-keys=dev,docs,tests", "--group=dev=all,docs,tests"] - repo: https://github.com/tox-dev/pyproject-fmt - rev: "2.2.1" + rev: "2.2.3" hooks: - id: pyproject-fmt +- repo: https://github.com/jshwi/docsig # Check docstrings against function sig + rev: v0.60.1 + hooks: + - id: docsig + args: + - --ignore-no-params # Allow docstrings without parameters + - --check-dunders # Check dunder methods + - --check-overridden # Check overridden methods + - --check-protected # Check protected methods + - --check-class # Check class docstrings + - --disable=E113 # Disable empty docstrings + - --summary # Print a summary +ci: + autoupdate_schedule: monthly diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a9c8f9..5678486 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,19 @@ Keep it human-readable, your future self will thank you! ## [Unreleased](https://github.com/ecmwf/anemoi-models/compare/0.3.0...HEAD) +### Added +- Codeowners file +- Pygrep precommit hooks +- Docsig precommit hooks +- Changelog merge strategy +- Configurability of the dropout probability in the MultiHeadSelfAttention module +- Variable Bounding as configurable model layers [#13](https://github.com/ecmwf/anemoi-models/issues/13) + +### Changed +- Bugfixes for CI + +### Removed + ## [0.3.0](https://github.com/ecmwf/anemoi-models/compare/0.2.1...0.3.0) - Remapping of (meteorological) Variables ### Added diff --git a/pyproject.toml b/pyproject.toml index 05a99c4..214f82c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,20 +11,13 @@ [build-system] build-backend = "setuptools.build_meta" -requires = [ - "setuptools>=61", - "setuptools-scm>=8", -] +requires = [ "setuptools>=61", "setuptools-scm>=8" ] [project] name = "anemoi-models" description = "A package to hold various functions to support training of ML models."
readme = "README.md" -keywords = [ - "ai", - "models", - "tools", -] +keywords = [ "ai", "models", "tools" ] license = { file = "LICENSE" } authors = [ @@ -47,9 +40,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: PyPy", ] -dynamic = [ - "version", -] +dynamic = [ "version" ] dependencies = [ "anemoi-utils>=0.1.9", "einops>=0.6.1", @@ -57,19 +48,9 @@ dependencies = [ "torch>=2.2", "torch-geometric>=2.3,<2.5", ] -optional-dependencies.all = [ -] +optional-dependencies.all = [ ] -optional-dependencies.dev = [ - "hypothesis", - "nbsphinx", - "pandoc", - "pytest", - "rstfmt", - "sphinx", - "sphinx-argparse<0.5", - "sphinx-rtd-theme", -] +optional-dependencies.dev = [ "anemoi-models[all,docs,tests]" ] optional-dependencies.docs = [ "nbsphinx", @@ -80,10 +61,7 @@ optional-dependencies.docs = [ "sphinx-rtd-theme", ] -optional-dependencies.tests = [ - "hypothesis", - "pytest", -] +optional-dependencies.tests = [ "hypothesis", "pytest" ] urls.Documentation = "https://anemoi-models.readthedocs.io/" urls.Homepage = "https://github.com/ecmwf/anemoi-models/" diff --git a/src/anemoi/models/layers/attention.py b/src/anemoi/models/layers/attention.py index 2063ad0..931e098 100644 --- a/src/anemoi/models/layers/attention.py +++ b/src/anemoi/models/layers/attention.py @@ -40,7 +40,7 @@ def __init__( bias: bool = False, is_causal: bool = False, window_size: Optional[int] = None, - dropout: float = 0.0, + dropout_p: float = 0.0, ): super().__init__() @@ -48,11 +48,11 @@ def __init__( embed_dim % num_heads == 0 ), f"Embedding dimension ({embed_dim}) must be divisible by number of heads ({num_heads})" - self.dropout = dropout self.num_heads = num_heads self.embed_dim = embed_dim self.head_dim = embed_dim // num_heads # q k v self.window_size = (window_size, window_size) # flash attention + self.dropout_p = dropout_p self.is_causal = is_causal self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias) @@ -86,15 +86,22 @@ def forward( query = shard_heads(query, shapes=shapes, mgroup=model_comm_group) key = shard_heads(key, shapes=shapes, mgroup=model_comm_group) value = shard_heads(value, shapes=shapes, mgroup=model_comm_group) + dropout_p = self.dropout_p if self.training else 0.0 if _FLASH_ATTENTION_AVAILABLE: query, key, value = ( einops.rearrange(t, "batch heads grid vars -> batch grid heads vars") for t in (query, key, value) ) - out = self.attention(query, key, value, causal=False, window_size=self.window_size) + out = self.attention(query, key, value, causal=False, window_size=self.window_size, dropout_p=dropout_p) out = einops.rearrange(out, "batch grid heads vars -> batch heads grid vars") else: - out = self.attention(query, key, value, is_causal=False) # expects (batch heads grid variable) format + out = self.attention( + query, + key, + value, + is_causal=False, + dropout_p=dropout_p, + ) # expects (batch heads grid variable) format out = shard_sequence(out, shapes=shapes, mgroup=model_comm_group) out = einops.rearrange(out, "batch heads grid vars -> (batch grid) (heads vars)") diff --git a/src/anemoi/models/layers/block.py b/src/anemoi/models/layers/block.py index ba29607..7fd3627 100644 --- a/src/anemoi/models/layers/block.py +++ b/src/anemoi/models/layers/block.py @@ -55,7 +55,15 @@ def forward( class TransformerProcessorBlock(BaseBlock): """Transformer block with MultiHeadSelfAttention and MLPs.""" - def __init__(self, num_channels, hidden_dim, num_heads, activation, window_size: int): + def __init__( + self, + num_channels: int, + hidden_dim: int, + num_heads: int, + 
activation: str, + window_size: int, + dropout_p: float = 0.0, + ): super().__init__() try: @@ -72,7 +80,7 @@ def __init__(self, num_channels, hidden_dim, num_heads, activation, window_size: window_size=window_size, bias=False, is_causal=False, - dropout=0.0, + dropout_p=dropout_p, ) self.mlp = nn.Sequential( diff --git a/src/anemoi/models/layers/bounding.py b/src/anemoi/models/layers/bounding.py new file mode 100644 index 0000000..3791ff2 --- /dev/null +++ b/src/anemoi/models/layers/bounding.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +from abc import ABC +from abc import abstractmethod + +import torch +from torch import nn + +from anemoi.models.data_indices.tensor import InputTensorIndex + + +class BaseBounding(nn.Module, ABC): + """Abstract base class for bounding strategies. + + This class defines an interface for bounding strategies which are used to apply a specific + restriction to the predictions of a model. + """ + + def __init__( + self, + *, + variables: list[str], + name_to_index: dict, + ) -> None: + super().__init__() + + self.name_to_index = name_to_index + self.variables = variables + self.data_index = self._create_index(variables=self.variables) + + def _create_index(self, variables: list[str]) -> InputTensorIndex: + return InputTensorIndex(includes=variables, excludes=[], name_to_index=self.name_to_index)._only + + @abstractmethod + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Applies the bounding to the predictions. + + Parameters + ---------- + x : torch.Tensor + The tensor containing the predictions that will be bounded. + + Returns + ------- + torch.Tensor + A tensor with the bounding applied. + """ + pass + + +class ReluBounding(BaseBounding): + """Initializes the bounding with a ReLU activation / zero clamping.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x[..., self.data_index] = torch.nn.functional.relu(x[..., self.data_index]) + return x + + +class HardtanhBounding(BaseBounding): + """Initializes the bounding with specified minimum and maximum values for bounding. + + Parameters + ---------- + variables : list[str] + A list of strings representing the variables that will be bounded. + name_to_index : dict + A dictionary mapping the variable names to their corresponding indices. + min_val : float + The minimum value for the HardTanh activation. + max_val : float + The maximum value for the HardTanh activation. + """ + + def __init__(self, *, variables: list[str], name_to_index: dict, min_val: float, max_val: float) -> None: + super().__init__(variables=variables, name_to_index=name_to_index) + self.min_val = min_val + self.max_val = max_val + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x[..., self.data_index] = torch.nn.functional.hardtanh( + x[..., self.data_index], min_val=self.min_val, max_val=self.max_val + ) + return x + + +class FractionBounding(HardtanhBounding): + """Initializes the FractionBounding with specified parameters. + + Parameters + ---------- + variables : list[str] + A list of strings representing the variables that will be bounded. + name_to_index : dict + A dictionary mapping the variable names to their corresponding indices. + min_val : float + The minimum value for the HardTanh activation. + max_val : float + The maximum value for the HardTanh activation. + total_var : str + A string representing a variable from which a secondary variable is derived. For + example, in the case of convective precipitation (Cp), total_var = Tp (total precipitation). 
+ """ + + def __init__( + self, *, variables: list[str], name_to_index: dict, min_val: float, max_val: float, total_var: str + ) -> None: + super().__init__(variables=variables, name_to_index=name_to_index, min_val=min_val, max_val=max_val) + self.total_variable = self._create_index(variables=[total_var]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Apply the HardTanh bounding to the data_index variables + x = super().forward(x) + # Calculate the fraction of the total variable + x[..., self.data_index] *= x[..., self.total_variable] + return x diff --git a/src/anemoi/models/layers/chunk.py b/src/anemoi/models/layers/chunk.py index 61dec34..87d0ac7 100644 --- a/src/anemoi/models/layers/chunk.py +++ b/src/anemoi/models/layers/chunk.py @@ -73,6 +73,7 @@ def __init__( num_heads: int = 16, mlp_hidden_ratio: int = 4, activation: str = "GELU", + dropout_p: float = 0.0, ) -> None: """Initialize TransformerProcessor. @@ -88,6 +89,8 @@ def __init__( ratio of mlp hidden dimension to embedding dimension, default 4 activation : str, optional Activation function, by default "GELU" + dropout_p: float + Dropout probability used for multi-head self attention, default 0.0 """ super().__init__(num_channels=num_channels, num_layers=num_layers) @@ -98,6 +101,7 @@ def __init__( num_heads=num_heads, activation=activation, window_size=window_size, + dropout_p=dropout_p, ) def forward( diff --git a/src/anemoi/models/layers/processor.py b/src/anemoi/models/layers/processor.py index bb33609..6ba8eb1 100644 --- a/src/anemoi/models/layers/processor.py +++ b/src/anemoi/models/layers/processor.py @@ -95,6 +95,7 @@ def __init__( cpu_offload: bool = False, num_heads: int = 16, mlp_hidden_ratio: int = 4, + dropout_p: float = 0.1, **kwargs, ) -> None: """Initialize TransformerProcessor. 
@@ -113,6 +114,8 @@ def __init__( ratio of mlp hidden dimension to embedding dimension, default 4 activation : str, optional Activation function, by default "GELU" + dropout_p: float, optional + Dropout probability used for multi-head self attention, default 0.1 """ super().__init__( num_channels=num_channels, @@ -133,6 +136,7 @@ def __init__( num_layers=self.chunk_size, window_size=window_size, activation=activation, + dropout_p=dropout_p, ) self.offload_layers(cpu_offload) diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py index a89328c..aa7e8bb 100644 --- a/src/anemoi/models/models/encoder_processor_decoder.py +++ b/src/anemoi/models/models/encoder_processor_decoder.py @@ -67,6 +67,8 @@ def __init__( self._register_latlon("data", self._graph_name_data) self._register_latlon("hidden", self._graph_name_hidden) + self.data_indices = data_indices + self.num_channels = model_config.model.num_channels input_dim = self.multi_step * self.num_input_channels + self.latlons_data.shape[1] + self.trainable_data_size @@ -103,6 +105,14 @@ def __init__( dst_grid_size=self._data_grid_size, ) + # Instantiation of model output bounding functions (e.g., to ensure outputs like TP are positive definite) + self.boundings = nn.ModuleList( + [ + instantiate(cfg, name_to_index=self.data_indices.model.output.name_to_index) + for cfg in getattr(model_config.model, "bounding", []) + ] + ) + def _calculate_shapes_and_indices(self, data_indices: dict) -> None: self.num_input_channels = len(data_indices.internal_model.input) self.num_output_channels = len(data_indices.internal_model.output) @@ -251,4 +261,9 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> # residual connection (just for the prognostic variables) x_out[..., self._internal_output_idx] += x[:, -1, :, :, self._internal_input_idx] + + for bounding in self.boundings: + # bounding performed in the order specified in the config file + x_out = bounding(x_out) + return x_out diff --git a/src/anemoi/models/preprocessing/__init__.py b/src/anemoi/models/preprocessing/__init__.py index 53017fb..cc2cb4f 100644 --- a/src/anemoi/models/preprocessing/__init__.py +++ b/src/anemoi/models/preprocessing/__init__.py @@ -38,7 +38,23 @@ def __init__( Data indices for input and output variables statistics : dict Data statistics dictionary + data_indices : dict + Data indices for input and output variables + + Attributes + ---------- + default : str + Default method for variables not specified in the config + method_config : dict + Dictionary of the methods with lists of variables + methods : dict + Dictionary of the variables with methods + data_indices : IndexCollection + Data indices for input and output variables + remap : dict + Dictionary of the variables with remapped names in the config """ + super().__init__() self.default, self.method_config = self._process_config(config) @@ -47,8 +63,10 @@ def __init__( self.data_indices = data_indices def _process_config(self, config): + _special_keys = ["default", "remap"] # Keys that do not contain a list of variables in a preprocessing method.
default = config.get("default", "none") - method_config = {k: v for k, v in config.items() if k != "default" and v is not None and v != "none"} + self.remap = config.get("remap", {}) + method_config = {k: v for k, v in config.items() if k not in _special_keys and v is not None and v != "none"} if not method_config: LOGGER.warning( diff --git a/src/anemoi/models/preprocessing/normalizer.py b/src/anemoi/models/preprocessing/normalizer.py index bc75466..ee6a4f5 100644 --- a/src/anemoi/models/preprocessing/normalizer.py +++ b/src/anemoi/models/preprocessing/normalizer.py @@ -49,6 +49,16 @@ def __init__( mean = statistics["mean"] stdev = statistics["stdev"] + # Optionally reuse statistic of one variable for another variable + statistics_remap = {} + for remap, source in self.remap.items(): + idx_src, idx_remap = name_to_index_training_input[source], name_to_index_training_input[remap] + statistics_remap[idx_remap] = (minimum[idx_src], maximum[idx_src], mean[idx_src], stdev[idx_src]) + + # Two-step to avoid overwriting the original statistics in the loop (this reduces dependence on order) + for idx, new_stats in statistics_remap.items(): + minimum[idx], maximum[idx], mean[idx], stdev[idx] = new_stats + self._validate_normalization_inputs(name_to_index_training_input, minimum, maximum, mean, stdev) _norm_add = np.zeros((minimum.size,), dtype=np.float32) @@ -56,6 +66,7 @@ def __init__( for name, i in name_to_index_training_input.items(): method = self.methods.get(name, self.default) + if method == "mean-std": LOGGER.debug(f"Normalizing: {name} is mean-std-normalised.") if stdev[i] < (mean[i] * 1e-6): @@ -63,6 +74,13 @@ def __init__( _norm_mul[i] = 1 / stdev[i] _norm_add[i] = -mean[i] / stdev[i] + elif method == "std": + LOGGER.debug(f"Normalizing: {name} is std-normalised.") + if stdev[i] < (mean[i] * 1e-6): + warnings.warn(f"Normalizing: the field seems to have only one value {mean[i]}") + _norm_mul[i] = 1 / stdev[i] + _norm_add[i] = 0 + elif method == "min-max": LOGGER.debug(f"Normalizing: {name} is min-max-normalised to [0, 1].") x = maximum[i] - minimum[i] @@ -92,16 +110,20 @@ def _validate_normalization_inputs(self, name_to_index_training_input: dict, min f"Error parsing methods in InputNormalizer methods ({len(self.methods)}) " f"and entries in config ({sum(len(v) for v in self.method_config)}) do not match." 
) + + # Check that all sizes align n = minimum.size assert maximum.size == n, (maximum.size, n) assert mean.size == n, (mean.size, n) assert stdev.size == n, (stdev.size, n) + # Check for typos in method config assert isinstance(self.methods, dict) for name, method in self.methods.items(): assert name in name_to_index_training_input, f"{name} is not a valid variable name" assert method in [ "mean-std", + "std", # "robust", "min-max", "max", diff --git a/tests/layers/block/test_block_transformer.py b/tests/layers/block/test_block_transformer.py index 97c274f..2e63386 100644 --- a/tests/layers/block/test_block_transformer.py +++ b/tests/layers/block/test_block_transformer.py @@ -29,11 +29,14 @@ class TestTransformerProcessorBlock: num_heads=st.integers(min_value=1, max_value=10), activation=st.sampled_from(["ReLU", "GELU", "Tanh"]), window_size=st.integers(min_value=1, max_value=512), + dropout_p=st.floats(min_value=0.0, max_value=1.0), ) @settings(max_examples=10) - def test_init(self, factor_attention_heads, hidden_dim, num_heads, activation, window_size): + def test_init(self, factor_attention_heads, hidden_dim, num_heads, activation, window_size, dropout_p): num_channels = num_heads * factor_attention_heads - block = TransformerProcessorBlock(num_channels, hidden_dim, num_heads, activation, window_size) + block = TransformerProcessorBlock( + num_channels, hidden_dim, num_heads, activation, window_size, dropout_p=dropout_p + ) assert isinstance(block, TransformerProcessorBlock) assert isinstance(block.layer_norm1, nn.LayerNorm) @@ -49,6 +52,7 @@ def test_init(self, factor_attention_heads, hidden_dim, num_heads, activation, w window_size=st.integers(min_value=1, max_value=512), shapes=st.lists(st.integers(min_value=1, max_value=10), min_size=3, max_size=3), batch_size=st.integers(min_value=1, max_value=40), + dropout_p=st.floats(min_value=0.0, max_value=1.0), ) @settings(max_examples=10) def test_forward_output( @@ -60,9 +64,12 @@ def test_forward_output( window_size, shapes, batch_size, + dropout_p, ): num_channels = num_heads * factor_attention_heads - block = TransformerProcessorBlock(num_channels, hidden_dim, num_heads, activation, window_size) + block = TransformerProcessorBlock( + num_channels, hidden_dim, num_heads, activation, window_size, dropout_p=dropout_p + ) x = torch.randn((batch_size, num_channels)) diff --git a/tests/layers/chunk/test_chunk_transformer.py b/tests/layers/chunk/test_chunk_transformer.py index 1fe7c6d..5449e97 100644 --- a/tests/layers/chunk/test_chunk_transformer.py +++ b/tests/layers/chunk/test_chunk_transformer.py @@ -20,6 +20,7 @@ def init(self): mlp_hidden_ratio: int = 4 activation: str = "GELU" window_size: int = 13 + dropout_p: float = 0.1 # num_heads must be evenly divisible by num_channels for MHSA return ( @@ -29,6 +30,7 @@ def init(self): mlp_hidden_ratio, activation, window_size, + dropout_p, ) @pytest.fixture @@ -40,6 +42,7 @@ def processor_chunk(self, init): mlp_hidden_ratio, activation, window_size, + dropout_p, ) = init return TransformerProcessorChunk( num_channels=num_channels, @@ -48,6 +51,7 @@ def processor_chunk(self, init): mlp_hidden_ratio=mlp_hidden_ratio, activation=activation, window_size=window_size, + dropout_p=dropout_p, ) def test_all_blocks(self, processor_chunk): diff --git a/tests/layers/processor/test_transformer_processor.py b/tests/layers/processor/test_transformer_processor.py index d359c27..305af41 100644 --- a/tests/layers/processor/test_transformer_processor.py +++ b/tests/layers/processor/test_transformer_processor.py 
@@ -21,6 +21,7 @@ def transformer_processor_init(): cpu_offload = False num_heads = 16 mlp_hidden_ratio = 4 + dropout_p = 0.1 return ( num_layers, window_size, @@ -30,6 +31,7 @@ def transformer_processor_init(): cpu_offload, num_heads, mlp_hidden_ratio, + dropout_p, ) @@ -44,6 +46,7 @@ def transformer_processor(transformer_processor_init): cpu_offload, num_heads, mlp_hidden_ratio, + dropout_p, ) = transformer_processor_init return TransformerProcessor( num_layers=num_layers, @@ -54,6 +57,7 @@ def transformer_processor(transformer_processor_init): cpu_offload=cpu_offload, num_heads=num_heads, mlp_hidden_ratio=mlp_hidden_ratio, + dropout_p=dropout_p, ) @@ -67,6 +71,7 @@ def test_transformer_processor_init(transformer_processor, transformer_processor _cpu_offload, _num_heads, _mlp_hidden_ratio, + _dropout_p, ) = transformer_processor_init assert isinstance(transformer_processor, TransformerProcessor) assert transformer_processor.num_chunks == num_chunks @@ -84,6 +89,7 @@ def test_transformer_processor_forward(transformer_processor, transformer_proces _cpu_offload, _num_heads, _mlp_hidden_ratio, + _dropout_p, ) = transformer_processor_init gridsize = 100 batch_size = 1 diff --git a/tests/layers/test_attention.py b/tests/layers/test_attention.py index ffeaebc..9457317 100644 --- a/tests/layers/test_attention.py +++ b/tests/layers/test_attention.py @@ -18,17 +18,19 @@ @given( num_heads=st.integers(min_value=1, max_value=50), embed_dim_multiplier=st.integers(min_value=1, max_value=10), + dropout_p=st.floats(min_value=0.0, max_value=1.0), ) -def test_multi_head_self_attention_init(num_heads, embed_dim_multiplier): +def test_multi_head_self_attention_init(num_heads, embed_dim_multiplier, dropout_p): embed_dim = ( num_heads * embed_dim_multiplier ) # TODO: Make assert in MHSA to check if embed_dim is divisible by num_heads - mhsa = MultiHeadSelfAttention(num_heads, embed_dim) + mhsa = MultiHeadSelfAttention(num_heads, embed_dim, dropout_p=dropout_p) assert isinstance(mhsa, nn.Module) assert mhsa.num_heads == num_heads assert mhsa.embed_dim == embed_dim assert mhsa.head_dim == embed_dim // num_heads + assert dropout_p == mhsa.dropout_p @pytest.mark.gpu @@ -36,11 +38,12 @@ def test_multi_head_self_attention_init(num_heads, embed_dim_multiplier): batch_size=st.integers(min_value=1, max_value=64), num_heads=st.integers(min_value=1, max_value=20), embed_dim_multiplier=st.integers(min_value=1, max_value=10), + dropout_p=st.floats(min_value=0.0, max_value=1.0), ) @settings(deadline=None) -def test_multi_head_self_attention_forward(batch_size, num_heads, embed_dim_multiplier): +def test_multi_head_self_attention_forward(batch_size, num_heads, embed_dim_multiplier, dropout_p): embed_dim = num_heads * embed_dim_multiplier - mhsa = MultiHeadSelfAttention(num_heads, embed_dim) + mhsa = MultiHeadSelfAttention(num_heads, embed_dim, dropout_p=dropout_p) x = torch.randn(batch_size * 2, embed_dim) shapes = [list(x.shape)] @@ -54,10 +57,11 @@ def test_multi_head_self_attention_forward(batch_size, num_heads, embed_dim_mult batch_size=st.integers(min_value=1, max_value=64), num_heads=st.integers(min_value=1, max_value=20), embed_dim_multiplier=st.integers(min_value=1, max_value=10), + dropout_p=st.floats(min_value=0.0, max_value=1.0), ) -def test_multi_head_self_attention_backward(batch_size, num_heads, embed_dim_multiplier): +def test_multi_head_self_attention_backward(batch_size, num_heads, embed_dim_multiplier, dropout_p): embed_dim = num_heads * embed_dim_multiplier - mhsa = MultiHeadSelfAttention(num_heads, 
embed_dim) + mhsa = MultiHeadSelfAttention(num_heads, embed_dim, dropout_p=dropout_p) x = torch.randn(batch_size * 2, embed_dim, requires_grad=True) shapes = [list(x.shape)] diff --git a/tests/layers/test_bounding.py b/tests/layers/test_bounding.py new file mode 100644 index 0000000..87619cd --- /dev/null +++ b/tests/layers/test_bounding.py @@ -0,0 +1,92 @@ +import pytest +import torch +from anemoi.utils.config import DotDict +from hydra.utils import instantiate + +from anemoi.models.layers.bounding import FractionBounding +from anemoi.models.layers.bounding import HardtanhBounding +from anemoi.models.layers.bounding import ReluBounding + + +@pytest.fixture +def config(): + return DotDict({"variables": ["var1", "var2"], "total_var": "total_var"}) + + +@pytest.fixture +def name_to_index(): + return {"var1": 0, "var2": 1, "total_var": 2} + + +@pytest.fixture +def input_tensor(): + return torch.tensor([[-1.0, 2.0, 3.0], [4.0, -5.0, 6.0], [0.5, 0.5, 0.5]]) + + +def test_relu_bounding(config, name_to_index, input_tensor): + bounding = ReluBounding(variables=config.variables, name_to_index=name_to_index) + output = bounding(input_tensor.clone()) + expected_output = torch.tensor([[0.0, 2.0, 3.0], [4.0, 0.0, 6.0], [0.5, 0.5, 0.5]]) + assert torch.equal(output, expected_output) + + +def test_hardtanh_bounding(config, name_to_index, input_tensor): + minimum, maximum = -1.0, 1.0 + bounding = HardtanhBounding( + variables=config.variables, name_to_index=name_to_index, min_val=minimum, max_val=maximum + ) + output = bounding(input_tensor.clone()) + expected_output = torch.tensor([[minimum, maximum, 3.0], [maximum, minimum, 6.0], [0.5, 0.5, 0.5]]) + assert torch.equal(output, expected_output) + + +def test_fraction_bounding(config, name_to_index, input_tensor): + bounding = FractionBounding( + variables=config.variables, name_to_index=name_to_index, min_val=0.0, max_val=1.0, total_var=config.total_var + ) + output = bounding(input_tensor.clone()) + expected_output = torch.tensor([[0.0, 3.0, 3.0], [6.0, 0.0, 6.0], [0.25, 0.25, 0.5]]) + + assert torch.equal(output, expected_output) + + +def test_multi_chained_bounding(config, name_to_index, input_tensor): + # Apply Relu first on the first variable only + bounding1 = ReluBounding(variables=config.variables[:-1], name_to_index=name_to_index) + expected_output = torch.tensor([[0.0, 2.0, 3.0], [4.0, -5.0, 6.0], [0.5, 0.5, 0.5]]) + # Check intermediate result + assert torch.equal(bounding1(input_tensor.clone()), expected_output) + minimum, maximum = 0.5, 1.75 + bounding2 = HardtanhBounding( + variables=config.variables, name_to_index=name_to_index, min_val=minimum, max_val=maximum + ) + # Use full chaining on the input tensor + output = bounding2(bounding1(input_tensor.clone())) + # Data with Relu applied first and then Hardtanh + expected_output = torch.tensor([[minimum, maximum, 3.0], [maximum, minimum, 6.0], [0.5, 0.5, 0.5]]) + assert torch.equal(output, expected_output) + + +def test_hydra_instantiate_bounding(config, name_to_index, input_tensor): + layer_definitions = [ + { + "_target_": "anemoi.models.layers.bounding.ReluBounding", + "variables": config.variables, + }, + { + "_target_": "anemoi.models.layers.bounding.HardtanhBounding", + "variables": config.variables, + "min_val": 0.0, + "max_val": 1.0, + }, + { + "_target_": "anemoi.models.layers.bounding.FractionBounding", + "variables": config.variables, + "min_val": 0.0, + "max_val": 1.0, + "total_var": config.total_var, + }, + ] + for layer_definition in layer_definitions: + bounding =
instantiate(layer_definition, name_to_index=name_to_index) + bounding(input_tensor.clone()) diff --git a/tests/preprocessing/test_preprocessor_normalizer.py b/tests/preprocessing/test_preprocessor_normalizer.py index cc527e7..8056865 100644 --- a/tests/preprocessing/test_preprocessor_normalizer.py +++ b/tests/preprocessing/test_preprocessor_normalizer.py @@ -40,6 +40,37 @@ def input_normalizer(): return InputNormalizer(config=config.data.normalizer, data_indices=data_indices, statistics=statistics) +@pytest.fixture() +def remap_normalizer(): + config = DictConfig( + { + "diagnostics": {"log": {"code": {"level": "DEBUG"}}}, + "data": { + "normalizer": { + "default": "mean-std", + "remap": {"x": "z", "y": "x"}, + "min-max": ["x"], + "max": ["y"], + "none": ["z"], + "mean-std": ["q"], + }, + "forcing": ["z", "q"], + "diagnostic": ["other"], + "remapped": {}, + }, + }, + ) + statistics = { + "mean": np.array([1.0, 2.0, 3.0, 4.5, 3.0]), + "stdev": np.array([0.5, 0.5, 0.5, 1, 14]), + "minimum": np.array([1.0, 1.0, 1.0, 1.0, 1.0]), + "maximum": np.array([11.0, 10.0, 10.0, 10.0, 10.0]), + } + name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4} + data_indices = IndexCollection(config=config, name_to_index=name_to_index) + return InputNormalizer(config=config.data.normalizer, statistics=statistics, data_indices=data_indices) + + def test_normalizer_not_inplace(input_normalizer) -> None: x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]]) input_normalizer(x, in_place=False) @@ -87,3 +118,15 @@ def test_normalize_inverse_transform(input_normalizer) -> None: assert torch.allclose( input_normalizer.inverse_transform(input_normalizer.transform(x, in_place=False), in_place=False), x ) + + +def test_normalizer_not_inplace_remap(remap_normalizer) -> None: + x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]]) + remap_normalizer(x, in_place=False) + assert torch.allclose(x, torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]])) + + +def test_normalize_remap(remap_normalizer) -> None: + x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]]) + expected_output = torch.Tensor([[0.0, 2 / 11, 3.0, -0.5, 1 / 7], [5 / 9, 7 / 11, 8.0, 4.5, 0.5]]) + assert torch.allclose(remap_normalizer.transform(x), expected_output)
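Usage note, given as an illustrative sketch rather than as part of the changeset: `AnemoiModelEncProcDec` builds `self.boundings` from `model_config.model.bounding` via Hydra `instantiate` and injects `name_to_index` itself (see encoder_processor_decoder.py above). A hypothetical `bounding` config section, using placeholder variable names `tp` and `cp`, could look like this:

# Hypothetical `model.bounding` entries; applied to the model output in the order listed.
# The variable names ("tp", "cp") are illustrative only.
bounding = [
    {
        "_target_": "anemoi.models.layers.bounding.ReluBounding",
        "variables": ["tp"],  # clamp total precipitation at zero
    },
    {
        "_target_": "anemoi.models.layers.bounding.FractionBounding",
        "variables": ["cp"],  # keep convective precipitation within [0, 1] of "tp"
        "min_val": 0.0,
        "max_val": 1.0,
        "total_var": "tp",
    },
]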
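Similarly, a minimal sketch of the new normalizer options exercised by the tests above, namely the `remap` key and the `std` method; the variable names `cp` and `tp` are again placeholders:

from omegaconf import DictConfig

# Keys of `remap` name the variable whose statistics are replaced; values name the
# variable supplying them. Here "cp" is normalized with "tp"'s statistics, while "tp"
# itself uses the new "std" method (divide by stdev, no mean subtraction).
normalizer_config = DictConfig(
    {
        "default": "mean-std",
        "remap": {"cp": "tp"},
        "std": ["tp"],
    }
)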