[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Dec 16, 2024
1 parent abb0afd commit c808dcf
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions src/anemoi/models/layers/bounding.py
@@ -80,7 +80,7 @@ class ReluBounding(BaseBounding):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x[..., self.data_index] = torch.nn.functional.relu(x[..., self.data_index])
return x


class NormalizedReluBounding(BaseBounding):
"""Bounding variable with a ReLU activation and customizable normalized thresholds."""
@@ -125,15 +125,15 @@ def __init__(
if not all(norm in {"mean-std", "min-max", "max", "std"} for norm in self.normalizer):
raise ValueError(
"Each normalizer must be one of: 'mean-std', 'min-max', 'max', 'std' in NormalizedReluBounding."
)
)
if len(self.normalizer) != len(variables):
raise ValueError(
"The length of the normalizer list must match the number of variables in NormalizedReluBounding."
)
)
if len(self.min_val) != len(variables):
raise ValueError(
"The length of the min_val list must match the number of variables in NormalizedReluBounding."
)
)

self.norm_min_val = torch.zeros(len(variables))
for ii, variable in enumerate(variables):
@@ -172,6 +172,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
)
return x


class HardtanhBounding(BaseBounding):
"""Initializes the bounding with specified minimum and maximum values for bounding.
@@ -187,7 +188,16 @@ class HardtanhBounding(BaseBounding):
The maximum value for the HardTanh activation.
"""

def __init__(self, *, variables: list[str], name_to_index: dict, min_val: float, max_val: float, statistics: Optional[dict] = None, name_to_index_stats: Optional[dict] = None,) -> None:
def __init__(
self,
*,
variables: list[str],
name_to_index: dict,
min_val: float,
max_val: float,
statistics: Optional[dict] = None,
name_to_index_stats: Optional[dict] = None,
) -> None:
super().__init__(variables=variables, name_to_index=name_to_index)
self.min_val = min_val
self.max_val = max_val
@@ -218,8 +228,16 @@ class FractionBounding(HardtanhBounding):
"""

def __init__(
self, *, variables: list[str], name_to_index: dict, min_val: float, max_val: float, total_var: str, statistics: Optional[dict] = None, name_to_index_stats: Optional[dict] = None,
) -> None:
self,
*,
variables: list[str],
name_to_index: dict,
min_val: float,
max_val: float,
total_var: str,
statistics: Optional[dict] = None,
name_to_index_stats: Optional[dict] = None,
) -> None:
super().__init__(variables=variables, name_to_index=name_to_index, min_val=min_val, max_val=max_val)
self.total_variable = self._create_index(variables=[total_var])

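For reference, the bounding layers whose signatures were reflowed above can be exercised roughly as follows. This is a minimal sketch, not part of the commit: the variable names, tensor shape, and the assumption that BaseBounding is a torch.nn.Module (so instances are callable) are illustrative; only the module path and the keyword-only constructor arguments come from the diff itself.

import torch

from anemoi.models.layers.bounding import FractionBounding, HardtanhBounding

# Hypothetical output layout: three variables stored along the last tensor dimension.
name_to_index = {"cloud_cover": 0, "precip_fraction": 1, "total_precip": 2}

# Clamp cloud cover into [0, 1] using the keyword-only signature shown in the diff.
clamp_cloud = HardtanhBounding(
    variables=["cloud_cover"],
    name_to_index=name_to_index,
    min_val=0.0,
    max_val=1.0,
)

# Bound precip_fraction to [0, 1] as a fraction of the total_precip variable.
bound_fraction = FractionBounding(
    variables=["precip_fraction"],
    name_to_index=name_to_index,
    min_val=0.0,
    max_val=1.0,
    total_var="total_precip",
)

x = torch.randn(4, 8, len(name_to_index))  # (batch, grid points, variables)
x = clamp_cloud(x)      # assumes BaseBounding subclasses torch.nn.Module
x = bound_fraction(x)

The optional statistics and name_to_index_stats arguments are left at their None defaults here, matching the defaults visible in the reformatted signatures.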