Skip to content

Commit

Permalink
feat: enable remapper for forcing variables
Browse files Browse the repository at this point in the history
  • Loading branch information
sahahner committed Aug 13, 2024
1 parent b633a69 commit c8c152e
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 10 deletions.
18 changes: 12 additions & 6 deletions src/anemoi/models/data_indices/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@ def __init__(self, config, name_to_index) -> None:
chain.from_iterable(d.items() for d in OmegaConf.to_container(config.data.remapped, resolve=True))
)
)
self.forcing_remapped = self.forcing.copy()

assert set(self.diagnostic).isdisjoint(self.forcing), (
f"Diagnostic and forcing variables overlap: {set(self.diagnostic).intersection(self.forcing)}. ",
"Please drop them at a dataset-level to exclude them from the training data.",
)
assert set(self.remapped).isdisjoint(self.forcing) and set(self.remapped).isdisjoint(self.diagnostic), (
"Remapped variable overlap with diagnostic and forcing variables. ",
"Not implemented.",
assert set(self.remapped).isdisjoint(self.diagnostic), (
"Remapped variable overlap with diagnostic variables. Not implemented.",
)
self.name_to_index = dict(sorted(name_to_index.items(), key=operator.itemgetter(1)))
name_to_index_internal_data_input = {
Expand All @@ -67,19 +67,25 @@ def __init__(self, config, name_to_index) -> None:
for key in self.remapped:
for mapped in self.remapped[key]:
name_to_index_internal_model_input[mapped] = len(name_to_index_internal_model_input)
name_to_index_internal_model_output[mapped] = len(name_to_index_internal_model_output)
name_to_index_internal_data_input[mapped] = len(name_to_index_internal_data_input)
if key not in self.forcing:
name_to_index_internal_model_output[mapped] = len(name_to_index_internal_model_output)
else:
self.forcing_remapped += [mapped]
if key in self.forcing:
# if key is in forcing we need to remove it from forcing_remapped after remapped variables have been added
self.forcing_remapped.remove(key)

self.data = DataIndex(self.diagnostic, self.forcing, self.name_to_index)
self.internal_data = DataIndex(
self.diagnostic,
self.forcing,
self.forcing_remapped,
name_to_index_internal_data_input,
) # internal after the remapping applied to data (training)
self.model = ModelIndex(self.diagnostic, self.forcing, name_to_index_model_input, name_to_index_model_output)
self.internal_model = ModelIndex(
self.diagnostic,
self.forcing,
self.forcing_remapped,
name_to_index_internal_model_input,
name_to_index_internal_model_output,
) # internal after the remapping applied to model (inference)
Expand Down
15 changes: 11 additions & 4 deletions src/anemoi/models/preprocessing/remapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,11 @@ def _create_remapping_indices(
self.index_training_input.append(name_to_index_training_input[name])
self.index_training_output.append(name_to_index_training_output[name])
self.index_inference_input.append(name_to_index_inference_input[name])
self.index_inference_output.append(name_to_index_inference_output[name])

if name in name_to_index_inference_output:
self.index_inference_output.append(name_to_index_inference_output[name])
else:
# this is a forcing variable. It is not in the inference output.
self.index_inference_output.append(None)
multiple_training_output, multiple_inference_output = [], []
multiple_training_input, multiple_inference_input = [], []
for name_dst in self.method_config[method][name]:
Expand All @@ -152,9 +155,13 @@ def _create_remapping_indices(
f"Remap {name} to {name_dst} in config.data.remapped. "
)
multiple_training_input.append(name_to_index_training_remapped_input[name_dst])
multiple_inference_input.append(name_to_index_inference_remapped_input[name_dst])
multiple_training_output.append(name_to_index_training_remapped_output[name_dst])
multiple_inference_output.append(name_to_index_inference_remapped_output[name_dst])
multiple_inference_input.append(name_to_index_inference_remapped_input[name_dst])
if name_dst in name_to_index_inference_remapped_output:
multiple_inference_output.append(name_to_index_inference_remapped_output[name_dst])
else:
# this is a forcing variable. It is not in the inference output.
multiple_inference_output.append(None)

self.index_training_remapped_input.append(multiple_training_input)
self.index_inference_remapped_input.append(multiple_inference_input)
Expand Down

0 comments on commit c8c152e

Please sign in to comment.