diff --git a/src/anemoi/models/interface/__init__.py b/src/anemoi/models/interface/__init__.py index c15d2ca3..85545e6e 100644 --- a/src/anemoi/models/interface/__init__.py +++ b/src/anemoi/models/interface/__init__.py @@ -10,7 +10,7 @@ import uuid import torch -from anemoi.utils.config import DotConfig +from anemoi.utils.config import DotDict from hydra.utils import instantiate from anemoi.models.data_indices.collection import IndexCollection @@ -22,7 +22,7 @@ class AnemoiModelInterface(torch.nn.Module): """Anemoi model on torch level.""" def __init__( - self, *, config: DotConfig, graph_data: dict, statistics: dict, data_indices: IndexCollection, metadata: dict + self, *, config: DotDict, graph_data: dict, statistics: dict, data_indices: IndexCollection, metadata: dict ) -> None: super().__init__() self.config = config diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py index 682e1058..8203bd84 100644 --- a/src/anemoi/models/models/encoder_processor_decoder.py +++ b/src/anemoi/models/models/encoder_processor_decoder.py @@ -12,7 +12,7 @@ import einops import torch -from anemoi.utils.config import DotConfig +from anemoi.utils.config import DotDict from hydra.utils import instantiate from torch import Tensor from torch import nn @@ -32,7 +32,7 @@ class AnemoiModelEncProcDec(nn.Module): def __init__( self, *, - config: DotConfig, + config: DotDict, data_indices: IndexCollection, graph_data: dict, ) -> None: @@ -130,7 +130,7 @@ def _assert_matching_indices(self, data_indices: dict) -> None: self._internal_output_idx, ), f"Model indices must match {self._internal_input_idx} != {self._internal_output_idx}" - def _define_tensor_sizes(self, config: DotConfig, graph_data: dict) -> None: + def _define_tensor_sizes(self, config: DotDict, graph_data: dict) -> None: # Define Sizes of different tensors self.num_nodes = {name: graph_data[name]["coords"].shape[0] for name in self._graph_mesh_names} self.num_node_features = {name: 2 * graph_data[name]["coords"].shape[1] for name in self._graph_mesh_names}