Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
fix: config import
Browse files Browse the repository at this point in the history
  • Loading branch information
theissenhelen committed May 30, 2024
1 parent 42b71d0 commit 98b7ba5
Showing 1 changed file with 31 additions and 4 deletions.
35 changes: 31 additions & 4 deletions src/anemoi/models/interface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import uuid

import torch
from anemoi.utils.config import DotConfig
from anemoi.utils.config import DotDict
from hydra.utils import instantiate
from torch_geometric.data import HeteroData

Expand All @@ -19,10 +19,37 @@


class AnemoiModelInterface(torch.nn.Module):
"""Anemoi model on torch level."""
"""An interface for Anemoi models.
This class is a wrapper around the Anemoi model that includes pre-processing and post-processing steps.
It inherits from the PyTorch Module class.
Attributes
----------
config : DotConfig
Configuration settings for the model.
id : str
A unique identifier for the model instance.
multi_step : bool
Whether the model uses multi-step input.
graph_data : HeteroData
Graph data for the model.
statistics : dict
Statistics for the data.
metadata : dict
Metadata for the model.
data_indices : dict
Indices for the data.
pre_processors : Processors
Pre-processing steps to apply to the data before passing it to the model.
post_processors : Processors
Post-processing steps to apply to the model's output.
model : AnemoiModelEncProcDec
The underlying Anemoi model.
"""

def __init__(
self, *, config: DotConfig, graph_data: HeteroData, statistics: dict, data_indices: dict, metadata: dict
self, *, config: DotDict, graph_data: HeteroData, statistics: dict, data_indices: dict, metadata: dict
) -> None:
super().__init__()
self.config = config
Expand All @@ -35,7 +62,7 @@ def __init__(
self._build_model()

def _build_model(self) -> None:
"""Build the model and pre- and post-processors."""
"""Builds the model and pre- and post-processors."""
# Instantiate processors
processors = [
[name, instantiate(processor, statistics=self.statistics, data_indices=self.data_indices)]
Expand Down

0 comments on commit 98b7ba5

Please sign in to comment.