diff --git a/src/anemoi/models/layers/bounding.py b/src/anemoi/models/layers/bounding.py index 738b4af6..50a0d08e 100644 --- a/src/anemoi/models/layers/bounding.py +++ b/src/anemoi/models/layers/bounding.py @@ -80,7 +80,7 @@ class ReluBounding(BaseBounding): def forward(self, x: torch.Tensor) -> torch.Tensor: x[..., self.data_index] = torch.nn.functional.relu(x[..., self.data_index]) return x - + class NormalizedReluBounding(BaseBounding): """Bounding variable with a ReLU activation and customizable normalized thresholds.""" @@ -125,15 +125,15 @@ def __init__( if not all(norm in {"mean-std", "min-max", "max", "std"} for norm in self.normalizer): raise ValueError( "Each normalizer must be one of: 'mean-std', 'min-max', 'max', 'std' in NormalizedReluBounding." - ) + ) if len(self.normalizer) != len(variables): raise ValueError( "The length of the normalizer list must match the number of variables in NormalizedReluBounding." - ) + ) if len(self.min_val) != len(variables): raise ValueError( "The length of the min_val list must match the number of variables in NormalizedReluBounding." - ) + ) self.norm_min_val = torch.zeros(len(variables)) for ii, variable in enumerate(variables): @@ -172,6 +172,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) return x + class HardtanhBounding(BaseBounding): """Initializes the bounding with specified minimum and maximum values for bounding. @@ -187,7 +188,16 @@ class HardtanhBounding(BaseBounding): The maximum value for the HardTanh activation. """ - def __init__(self, *, variables: list[str], name_to_index: dict, min_val: float, max_val: float, statistics: Optional[dict] = None, name_to_index_stats: Optional[dict] = None,) -> None: + def __init__( + self, + *, + variables: list[str], + name_to_index: dict, + min_val: float, + max_val: float, + statistics: Optional[dict] = None, + name_to_index_stats: Optional[dict] = None, + ) -> None: super().__init__(variables=variables, name_to_index=name_to_index) self.min_val = min_val self.max_val = max_val @@ -218,8 +228,16 @@ class FractionBounding(HardtanhBounding): """ def __init__( - self, *, variables: list[str], name_to_index: dict, min_val: float, max_val: float, total_var: str, statistics: Optional[dict] = None, name_to_index_stats: Optional[dict] = None, - ) -> None: + self, + *, + variables: list[str], + name_to_index: dict, + min_val: float, + max_val: float, + total_var: str, + statistics: Optional[dict] = None, + name_to_index_stats: Optional[dict] = None, + ) -> None: super().__init__(variables=variables, name_to_index=name_to_index, min_val=min_val, max_val=max_val) self.total_variable = self._create_index(variables=[total_var])