From 70a48f232d796396c88eb0f84eea5e6611dd9b13 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Wed, 5 Jun 2024 15:46:23 +0000 Subject: [PATCH] fix: import from utils --- src/anemoi/models/interface/__init__.py | 4 ++-- src/anemoi/models/models/encoder_processor_decoder.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/anemoi/models/interface/__init__.py b/src/anemoi/models/interface/__init__.py index f4ff82cf..a5f30608 100644 --- a/src/anemoi/models/interface/__init__.py +++ b/src/anemoi/models/interface/__init__.py @@ -10,7 +10,7 @@ import uuid import torch -from anemoi.utils.config import DotConfig +from anemoi.utils.config import DotDict from hydra.utils import instantiate from torch_geometric.data import HeteroData @@ -22,7 +22,7 @@ class AnemoiModelInterface(torch.nn.Module): """Anemoi model on torch level.""" def __init__( - self, *, config: DotConfig, graph_data: HeteroData, statistics: dict, data_indices: dict, metadata: dict + self, *, config: DotDict, graph_data: HeteroData, statistics: dict, data_indices: dict, metadata: dict ) -> None: super().__init__() self.config = config diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py index 3bab80e6..a8525913 100644 --- a/src/anemoi/models/models/encoder_processor_decoder.py +++ b/src/anemoi/models/models/encoder_processor_decoder.py @@ -12,7 +12,7 @@ import einops import torch -from anemoi.utils.config import DotConfig +from anemoi.utils.config import DotDict from hydra.utils import instantiate from torch import Tensor from torch import nn @@ -32,7 +32,7 @@ class AnemoiModelEncProcDec(nn.Module): def __init__( self, *, - config: DotConfig, + config: DotDict, data_indices: dict, graph_data: HeteroData, ) -> None: @@ -119,7 +119,7 @@ def _assert_matching_indices(self, data_indices: dict) -> None: self._internal_output_idx, ), f"Model indices must match {self._internal_input_idx} != {self._internal_output_idx}" - def _define_tensor_sizes(self, config: DotConfig) -> None: + def _define_tensor_sizes(self, config: DotDict) -> None: # Define Sizes of different tensors self._data_grid_size = self._graph_data[(self._graph_name_data, "to", self._graph_name_data)].ecoords_rad.shape[ 0