From 33dccb9d4387724290bce6d3c62fc22a68449fc7 Mon Sep 17 00:00:00 2001 From: sahahner Date: Wed, 14 Aug 2024 13:07:21 +0000 Subject: [PATCH] tests: include remapping forcing variable and do not test with remapping variables at the end --- tests/data_indices/test_collection.py | 79 +++++++++++-------- .../test_preprocessor_remapper.py | 12 +-- 2 files changed, 52 insertions(+), 39 deletions(-) diff --git a/tests/data_indices/test_collection.py b/tests/data_indices/test_collection.py index 69bf5fc..78b3ecd 100644 --- a/tests/data_indices/test_collection.py +++ b/tests/data_indices/test_collection.py @@ -17,46 +17,49 @@ def data_indices(): config = DictConfig( { "data": { - "forcing": ["x"], + "forcing": ["x", "e"], "diagnostic": ["z", "q"], "remapped": [ { + "e": ["e_1", "e_2"], "d": ["d_1", "d_2"], } ], }, }, ) - name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4, "d": 5} + name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "e": 4, "d": 5, "other": 6} return IndexCollection(config=config, name_to_index=name_to_index) def test_dataindices_init(data_indices) -> None: - assert data_indices.data.input.includes == ["x"] + assert data_indices.data.input.includes == ["x", "e"] assert data_indices.data.input.excludes == ["z", "q"] - assert data_indices.internal_data.input.includes == ["x"] + assert data_indices.internal_data.input.includes == ["x", "e_1", "e_2"] assert data_indices.internal_data.input.excludes == ["z", "q"] assert data_indices.internal_data.output.includes == ["z", "q"] - assert data_indices.internal_data.output.excludes == ["x"] + assert data_indices.internal_data.output.excludes == ["x", "e_1", "e_2"] assert data_indices.data.output.includes == ["z", "q"] - assert data_indices.data.output.excludes == ["x"] - assert data_indices.model.input.includes == ["x"] + assert data_indices.data.output.excludes == ["x", "e"] + assert data_indices.model.input.includes == ["x", "e"] assert data_indices.model.input.excludes == [] - assert data_indices.internal_model.input.includes == ["x"] + assert data_indices.internal_model.input.includes == ["x", "e_1", "e_2"] assert data_indices.internal_model.input.excludes == [] assert data_indices.internal_model.output.includes == ["z", "q"] assert data_indices.internal_model.output.excludes == [] assert data_indices.model.output.includes == ["z", "q"] assert data_indices.model.output.excludes == [] - assert data_indices.data.input.name_to_index == {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4, "d": 5} + assert data_indices.data.input.name_to_index == {"x": 0, "y": 1, "z": 2, "q": 3, "e": 4, "d": 5, "other": 6} assert data_indices.internal_data.input.name_to_index == { "x": 0, "y": 1, "z": 2, "q": 3, "other": 4, - "d_1": 5, - "d_2": 6, + "e_1": 5, + "e_2": 6, + "d_1": 7, + "d_2": 8, } assert data_indices.internal_data.output.name_to_index == { "x": 0, @@ -64,14 +67,24 @@ def test_dataindices_init(data_indices) -> None: "z": 2, "q": 3, "other": 4, + "e_1": 5, + "e_2": 6, + "d_1": 7, + "d_2": 8, + } + assert data_indices.data.output.name_to_index == {"x": 0, "y": 1, "z": 2, "q": 3, "e": 4, "d": 5, "other": 6} + assert data_indices.model.input.name_to_index == {"x": 0, "y": 1, "e": 2, "d": 3, "other": 4} + assert data_indices.internal_model.input.name_to_index == { + "x": 0, + "y": 1, + "other": 2, + "e_1": 3, + "e_2": 4, "d_1": 5, "d_2": 6, } - assert data_indices.data.output.name_to_index == {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4, "d": 5} - assert data_indices.model.input.name_to_index == {"x": 0, "y": 1, "other": 2, "d": 3} - assert data_indices.internal_model.input.name_to_index == {"x": 0, "y": 1, "other": 2, "d_1": 3, "d_2": 4} assert data_indices.internal_model.output.name_to_index == {"y": 0, "z": 1, "q": 2, "other": 3, "d_1": 4, "d_2": 5} - assert data_indices.model.output.name_to_index == {"y": 0, "z": 1, "q": 2, "other": 3, "d": 4} + assert data_indices.model.output.name_to_index == {"y": 0, "z": 1, "q": 2, "d": 3, "other": 4} def test_dataindices_max(data_indices) -> None: @@ -90,16 +103,16 @@ def test_dataindices_max(data_indices) -> None: def test_dataindices_todict(data_indices) -> None: expected_output = { "input": { - "full": torch.Tensor([0, 1, 4, 5]).to(torch.int), - "forcing": torch.Tensor([0]).to(torch.int), + "full": torch.Tensor([0, 1, 4, 5, 6]).to(torch.int), + "forcing": torch.Tensor([0, 4]).to(torch.int), "diagnostic": torch.Tensor([2, 3]).to(torch.int), - "prognostic": torch.Tensor([1, 4, 5]).to(torch.int), + "prognostic": torch.Tensor([1, 5, 6]).to(torch.int), }, "output": { - "full": torch.Tensor([1, 2, 3, 4, 5]).to(torch.int), - "forcing": torch.Tensor([0]).to(torch.int), + "full": torch.Tensor([1, 2, 3, 5, 6]).to(torch.int), + "forcing": torch.Tensor([0, 4]).to(torch.int), "diagnostic": torch.Tensor([2, 3]).to(torch.int), - "prognostic": torch.Tensor([1, 4, 5]).to(torch.int), + "prognostic": torch.Tensor([1, 5, 6]).to(torch.int), }, } @@ -112,16 +125,16 @@ def test_dataindices_todict(data_indices) -> None: def test_internaldataindices_todict(data_indices) -> None: expected_output = { "input": { - "full": torch.Tensor([0, 1, 4, 5, 6]).to(torch.int), - "forcing": torch.Tensor([0]).to(torch.int), + "full": torch.Tensor([0, 1, 4, 5, 6, 7, 8]).to(torch.int), + "forcing": torch.Tensor([0, 5, 6]).to(torch.int), "diagnostic": torch.Tensor([2, 3]).to(torch.int), - "prognostic": torch.Tensor([1, 4, 5, 6]).to(torch.int), + "prognostic": torch.Tensor([1, 4, 7, 8]).to(torch.int), }, "output": { - "full": torch.Tensor([1, 2, 3, 4, 5, 6]).to(torch.int), - "forcing": torch.Tensor([0]).to(torch.int), + "full": torch.Tensor([1, 2, 3, 4, 7, 8]).to(torch.int), + "forcing": torch.Tensor([0, 5, 6]).to(torch.int), "diagnostic": torch.Tensor([2, 3]).to(torch.int), - "prognostic": torch.Tensor([1, 4, 5, 6]).to(torch.int), + "prognostic": torch.Tensor([1, 4, 7, 8]).to(torch.int), }, } @@ -134,10 +147,10 @@ def test_internaldataindices_todict(data_indices) -> None: def test_modelindices_todict(data_indices) -> None: expected_output = { "input": { - "full": torch.Tensor([0, 1, 2, 3]).to(torch.int), - "forcing": torch.Tensor([0]).to(torch.int), + "full": torch.Tensor([0, 1, 2, 3, 4]).to(torch.int), + "forcing": torch.Tensor([0, 2]).to(torch.int), "diagnostic": torch.Tensor([]).to(torch.int), - "prognostic": torch.Tensor([1, 2, 3]).to(torch.int), + "prognostic": torch.Tensor([1, 3, 4]).to(torch.int), }, "output": { "full": torch.Tensor([0, 1, 2, 3, 4]).to(torch.int), @@ -156,10 +169,10 @@ def test_modelindices_todict(data_indices) -> None: def test_internalmodelindices_todict(data_indices) -> None: expected_output = { "input": { - "full": torch.Tensor([0, 1, 2, 3, 4]).to(torch.int), - "forcing": torch.Tensor([0]).to(torch.int), + "full": torch.Tensor([0, 1, 2, 3, 4, 5, 6]).to(torch.int), + "forcing": torch.Tensor([0, 3, 4]).to(torch.int), "diagnostic": torch.Tensor([]).to(torch.int), - "prognostic": torch.Tensor([1, 2, 3, 4]).to(torch.int), + "prognostic": torch.Tensor([1, 2, 5, 6]).to(torch.int), }, "output": { "full": torch.Tensor([0, 1, 2, 3, 4, 5]).to(torch.int), diff --git a/tests/preprocessing/test_preprocessor_remapper.py b/tests/preprocessing/test_preprocessor_remapper.py index d4bcbb4..d5bd52a 100644 --- a/tests/preprocessing/test_preprocessor_remapper.py +++ b/tests/preprocessing/test_preprocessor_remapper.py @@ -37,19 +37,19 @@ def input_remapper(): }, ) statistics = {} - name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4, "d": 5} + name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "d": 4, "other": 5} data_indices = IndexCollection(config=config, name_to_index=name_to_index) return Remapper(config=config.data.remapper, data_indices=data_indices, statistics=statistics) def test_remap_not_inplace(input_remapper) -> None: - x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 150.0], [6.0, 7.0, 8.0, 9.0, 10.0, 201.0]]) + x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]]) input_remapper(x, in_place=False) - assert torch.allclose(x, torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 150.0], [6.0, 7.0, 8.0, 9.0, 10.0, 201.0]])) + assert torch.allclose(x, torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]])) def test_remap(input_remapper) -> None: - x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 150.0], [6.0, 7.0, 8.0, 9.0, 10.0, 201.0]]) + x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]]) expected_output = torch.Tensor( [[1.0, 2.0, 3.0, 4.0, 5.0, -0.8660254, 0.5], [6.0, 7.0, 8.0, 9.0, 10.0, -0.93358043, -0.35836795]] ) @@ -58,12 +58,12 @@ def test_remap(input_remapper) -> None: def test_inverse_transform(input_remapper) -> None: x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0, -0.8660254, 0.5], [6.0, 7.0, 8.0, 9.0, 10.0, -0.93358043, -0.35836795]]) - expected_output = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 150.0], [6.0, 7.0, 8.0, 9.0, 10.0, 201.0]]) + expected_output = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]]) assert torch.allclose(input_remapper.inverse_transform(x), expected_output) def test_remap_inverse_transform(input_remapper) -> None: - x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 150.0], [6.0, 7.0, 8.0, 9.0, 10.0, 201.0]]) + x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]]) assert torch.allclose( input_remapper.inverse_transform(input_remapper.transform(x, in_place=False), in_place=False), x )