-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'origin/develop' into feature/28-make-mo…
…dels-switchable-through-the-config
- Loading branch information
Showing
22 changed files
with
425 additions
and
53 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
CHANGELOG.md merge=union |
Validating CODEOWNERS rules …
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# CODEOWNERS file | ||
|
||
# Protect workflow files | ||
/.github/ @theissenhelen @jesperdramsch @gmertes | ||
/.pre-commit-config.yaml @theissenhelen @jesperdramsch @gmertes | ||
/pyproject.toml @theissenhelen @jesperdramsch @gmertes |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
build: | ||
python: '3.10' | ||
modules: | ||
- ninja | ||
python_dependencies: | ||
- ecmwf/anemoi-utils@develop | ||
parallel: 64 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
from __future__ import annotations | ||
|
||
from abc import ABC | ||
from abc import abstractmethod | ||
|
||
import torch | ||
from torch import nn | ||
|
||
from anemoi.models.data_indices.tensor import InputTensorIndex | ||
|
||
|
||
class BaseBounding(nn.Module, ABC): | ||
"""Abstract base class for bounding strategies. | ||
This class defines an interface for bounding strategies which are used to apply a specific | ||
restriction to the predictions of a model. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
*, | ||
variables: list[str], | ||
name_to_index: dict, | ||
) -> None: | ||
super().__init__() | ||
|
||
self.name_to_index = name_to_index | ||
self.variables = variables | ||
self.data_index = self._create_index(variables=self.variables) | ||
|
||
def _create_index(self, variables: list[str]) -> InputTensorIndex: | ||
return InputTensorIndex(includes=variables, excludes=[], name_to_index=self.name_to_index)._only | ||
|
||
@abstractmethod | ||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
"""Applies the bounding to the predictions. | ||
Parameters | ||
---------- | ||
x : torch.Tensor | ||
The tensor containing the predictions that will be bounded. | ||
Returns | ||
------- | ||
torch.Tensor | ||
A tensor with the bounding applied. | ||
""" | ||
pass | ||
|
||
|
||
class ReluBounding(BaseBounding): | ||
"""Initializes the bounding with a ReLU activation / zero clamping.""" | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
x[..., self.data_index] = torch.nn.functional.relu(x[..., self.data_index]) | ||
return x | ||
|
||
|
||
class HardtanhBounding(BaseBounding): | ||
"""Initializes the bounding with specified minimum and maximum values for bounding. | ||
Parameters | ||
---------- | ||
variables : list[str] | ||
A list of strings representing the variables that will be bounded. | ||
name_to_index : dict | ||
A dictionary mapping the variable names to their corresponding indices. | ||
min_val : float | ||
The minimum value for the HardTanh activation. | ||
max_val : float | ||
The maximum value for the HardTanh activation. | ||
""" | ||
|
||
def __init__(self, *, variables: list[str], name_to_index: dict, min_val: float, max_val: float) -> None: | ||
super().__init__(variables=variables, name_to_index=name_to_index) | ||
self.min_val = min_val | ||
self.max_val = max_val | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
x[..., self.data_index] = torch.nn.functional.hardtanh( | ||
x[..., self.data_index], min_val=self.min_val, max_val=self.max_val | ||
) | ||
return x | ||
|
||
|
||
class FractionBounding(HardtanhBounding): | ||
"""Initializes the FractionBounding with specified parameters. | ||
Parameters | ||
---------- | ||
variables : list[str] | ||
A list of strings representing the variables that will be bounded. | ||
name_to_index : dict | ||
A dictionary mapping the variable names to their corresponding indices. | ||
min_val : float | ||
The minimum value for the HardTanh activation. | ||
max_val : float | ||
The maximum value for the HardTanh activation. | ||
total_var : str | ||
A string representing a variable from which a secondary variable is derived. For | ||
example, in the case of convective precipitation (Cp), total_var = Tp (total precipitation). | ||
""" | ||
|
||
def __init__( | ||
self, *, variables: list[str], name_to_index: dict, min_val: float, max_val: float, total_var: str | ||
) -> None: | ||
super().__init__(variables=variables, name_to_index=name_to_index, min_val=min_val, max_val=max_val) | ||
self.total_variable = self._create_index(variables=[total_var]) | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
# Apply the HardTanh bounding to the data_index variables | ||
x = super().forward(x) | ||
# Calculate the fraction of the total variable | ||
x[..., self.data_index] *= x[..., self.total_variable] | ||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.