Skip to content

Commit

Permalink
docs: add docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
JesperDramsch committed Sep 23, 2024
1 parent 7da5bbe commit 9c28777
Showing 1 changed file with 29 additions and 13 deletions.
42 changes: 29 additions & 13 deletions src/anemoi/models/layers/bounding.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,14 @@ class BaseBounding(nn.Module, ABC):
This class defines an interface for bounding strategies which are used to apply a specific
restriction to the predictions of a model.
Parameters
----------
x : torch.Tensor
The tensor containing the predictions that will be bounded.
Returns
-------
torch.Tensor
A tensor with the bounding applied.
"""

def __init__(
self,
*,
variables: list[str],
name_to_index: dict,
):
) -> None:
super().__init__()

self.name_to_index = name_to_index
Expand All @@ -43,10 +33,24 @@ def _create_index(self, variables: list[str]) -> InputTensorIndex:

@abstractmethod
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Applies the bounding to the predictions.
Parameters
----------
x : torch.Tensor
The tensor containing the predictions that will be bounded.
Returns
-------
torch.Tensor
A tensor with the bounding applied.
"""
pass


class ReluBounding(BaseBounding):
"""Initializes the bounding with a ReLU activation / zero clamping.
"""
def forward(self, x: torch.Tensor) -> torch.Tensor:
x[..., self.data_index] = torch.nn.functional.relu(x[..., self.data_index])
return x
Expand All @@ -57,13 +61,17 @@ class HardtanhBounding(BaseBounding):
Parameters
----------
variables : list[str]
A list of strings representing the variables that will be bounded.
name_to_index : dict
A dictionary mapping the variable names to their corresponding indices.
min_val : float
The minimum value for the HardTanh activation.
max_val : float
The maximum value for the HardTanh activation.
"""

def __init__(self, *, variables: list[str], name_to_index: dict, min_val: float, max_val: float):
def __init__(self, *, variables: list[str], name_to_index: dict, min_val: float, max_val: float) -> None:
super().__init__(variables=variables, name_to_index=name_to_index)
self.min_val = min_val
self.max_val = max_val
Expand All @@ -80,12 +88,20 @@ class FractionBounding(HardtanhBounding):
Parameters
----------
variables : list[str]
A list of strings representing the variables that will be bounded.
name_to_index : dict
A dictionary mapping the variable names to their corresponding indices.
min_val : float
The minimum value for the HardTanh activation.
max_val : float
The maximum value for the HardTanh activation.
total_var : str
A string representing a variable from which a secondary variable is derived. For
example, in the case of convective precipitation (Cp), total_var = Tp (total precipitation).
"""

def __init__(self, *, variables: list[str], name_to_index: dict, min_val: float, max_val: float, total_var: str):
def __init__(self, *, variables: list[str], name_to_index: dict, min_val: float, max_val: float, total_var: str) -> None:
super().__init__(variables=variables, name_to_index=name_to_index, min_val=min_val, max_val=max_val)
self.total_variable = self._create_index(variables=[total_var])

Expand Down

0 comments on commit 9c28777

Please sign in to comment.