Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
fix: argument order and explicit types
Browse files Browse the repository at this point in the history
  • Loading branch information
theissenhelen committed May 28, 2024
1 parent 9500a13 commit 92ad881
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
11 changes: 7 additions & 4 deletions src/anemoi/models/preprocessing/imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import torch

from anemoi.models.data_indices.collection import IndexCollection
from anemoi.models.preprocessing import BasePreprocessor

LOGGER = logging.getLogger(__name__)
Expand All @@ -24,8 +25,8 @@ class BaseImputer(BasePreprocessor, ABC):
def __init__(
self,
config=None,
data_indices: Optional[IndexCollection] = None,
statistics: Optional[dict] = None,
data_indices: Optional[dict] = None,
) -> None:
"""Initialize the imputer.
Expand Down Expand Up @@ -176,7 +177,7 @@ def __init__(
statistics: Optional[dict] = None,
data_indices: Optional[dict] = None,
) -> None:
super().__init__(config, statistics, data_indices)
super().__init__(config, data_indices, statistics)

self._create_imputation_indices(statistics)

Expand All @@ -199,8 +200,10 @@ class ConstantImputer(BaseImputer):
```
"""

def __init__(self, config=None, statistics: Optional[dict] = None, data_indices: Optional[dict] = None) -> None:
super().__init__(config, statistics, data_indices)
def __init__(
self, config=None, statistics: Optional[dict] = None, data_indices: Optional[IndexCollection] = None
) -> None:
super().__init__(config, data_indices, statistics)

self._create_imputation_indices()

Expand Down
3 changes: 2 additions & 1 deletion src/anemoi/models/preprocessing/normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy as np
import torch

from anemoi.models.data_indices.collection import IndexCollection
from anemoi.models.preprocessing import BasePreprocessor

LOGGER = logging.getLogger(__name__)
Expand All @@ -25,7 +26,7 @@ class InputNormalizer(BasePreprocessor):
def __init__(
self,
config=None,
data_indices: Optional[dict] = None,
data_indices: Optional[IndexCollection] = None,
statistics: Optional[dict] = None,
) -> None:
"""Initialize the normalizer.
Expand Down

0 comments on commit 92ad881

Please sign in to comment.