Skip to content

Commit

Permalink
tests: include remapping forcing variable and do not test with remapp…
Browse files Browse the repository at this point in the history
…ing variables at the end
  • Loading branch information
sahahner committed Aug 14, 2024
1 parent c8c152e commit 33dccb9
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 39 deletions.
79 changes: 46 additions & 33 deletions tests/data_indices/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,61 +17,74 @@ def data_indices():
config = DictConfig(
{
"data": {
"forcing": ["x"],
"forcing": ["x", "e"],
"diagnostic": ["z", "q"],
"remapped": [
{
"e": ["e_1", "e_2"],
"d": ["d_1", "d_2"],
}
],
},
},
)
name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4, "d": 5}
name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "e": 4, "d": 5, "other": 6}
return IndexCollection(config=config, name_to_index=name_to_index)


def test_dataindices_init(data_indices) -> None:
assert data_indices.data.input.includes == ["x"]
assert data_indices.data.input.includes == ["x", "e"]
assert data_indices.data.input.excludes == ["z", "q"]
assert data_indices.internal_data.input.includes == ["x"]
assert data_indices.internal_data.input.includes == ["x", "e_1", "e_2"]
assert data_indices.internal_data.input.excludes == ["z", "q"]
assert data_indices.internal_data.output.includes == ["z", "q"]
assert data_indices.internal_data.output.excludes == ["x"]
assert data_indices.internal_data.output.excludes == ["x", "e_1", "e_2"]
assert data_indices.data.output.includes == ["z", "q"]
assert data_indices.data.output.excludes == ["x"]
assert data_indices.model.input.includes == ["x"]
assert data_indices.data.output.excludes == ["x", "e"]
assert data_indices.model.input.includes == ["x", "e"]
assert data_indices.model.input.excludes == []
assert data_indices.internal_model.input.includes == ["x"]
assert data_indices.internal_model.input.includes == ["x", "e_1", "e_2"]
assert data_indices.internal_model.input.excludes == []
assert data_indices.internal_model.output.includes == ["z", "q"]
assert data_indices.internal_model.output.excludes == []
assert data_indices.model.output.includes == ["z", "q"]
assert data_indices.model.output.excludes == []
assert data_indices.data.input.name_to_index == {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4, "d": 5}
assert data_indices.data.input.name_to_index == {"x": 0, "y": 1, "z": 2, "q": 3, "e": 4, "d": 5, "other": 6}
assert data_indices.internal_data.input.name_to_index == {
"x": 0,
"y": 1,
"z": 2,
"q": 3,
"other": 4,
"d_1": 5,
"d_2": 6,
"e_1": 5,
"e_2": 6,
"d_1": 7,
"d_2": 8,
}
assert data_indices.internal_data.output.name_to_index == {
"x": 0,
"y": 1,
"z": 2,
"q": 3,
"other": 4,
"e_1": 5,
"e_2": 6,
"d_1": 7,
"d_2": 8,
}
assert data_indices.data.output.name_to_index == {"x": 0, "y": 1, "z": 2, "q": 3, "e": 4, "d": 5, "other": 6}
assert data_indices.model.input.name_to_index == {"x": 0, "y": 1, "e": 2, "d": 3, "other": 4}
assert data_indices.internal_model.input.name_to_index == {
"x": 0,
"y": 1,
"other": 2,
"e_1": 3,
"e_2": 4,
"d_1": 5,
"d_2": 6,
}
assert data_indices.data.output.name_to_index == {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4, "d": 5}
assert data_indices.model.input.name_to_index == {"x": 0, "y": 1, "other": 2, "d": 3}
assert data_indices.internal_model.input.name_to_index == {"x": 0, "y": 1, "other": 2, "d_1": 3, "d_2": 4}
assert data_indices.internal_model.output.name_to_index == {"y": 0, "z": 1, "q": 2, "other": 3, "d_1": 4, "d_2": 5}
assert data_indices.model.output.name_to_index == {"y": 0, "z": 1, "q": 2, "other": 3, "d": 4}
assert data_indices.model.output.name_to_index == {"y": 0, "z": 1, "q": 2, "d": 3, "other": 4}


def test_dataindices_max(data_indices) -> None:
Expand All @@ -90,16 +103,16 @@ def test_dataindices_max(data_indices) -> None:
def test_dataindices_todict(data_indices) -> None:
expected_output = {
"input": {
"full": torch.Tensor([0, 1, 4, 5]).to(torch.int),
"forcing": torch.Tensor([0]).to(torch.int),
"full": torch.Tensor([0, 1, 4, 5, 6]).to(torch.int),
"forcing": torch.Tensor([0, 4]).to(torch.int),
"diagnostic": torch.Tensor([2, 3]).to(torch.int),
"prognostic": torch.Tensor([1, 4, 5]).to(torch.int),
"prognostic": torch.Tensor([1, 5, 6]).to(torch.int),
},
"output": {
"full": torch.Tensor([1, 2, 3, 4, 5]).to(torch.int),
"forcing": torch.Tensor([0]).to(torch.int),
"full": torch.Tensor([1, 2, 3, 5, 6]).to(torch.int),
"forcing": torch.Tensor([0, 4]).to(torch.int),
"diagnostic": torch.Tensor([2, 3]).to(torch.int),
"prognostic": torch.Tensor([1, 4, 5]).to(torch.int),
"prognostic": torch.Tensor([1, 5, 6]).to(torch.int),
},
}

Expand All @@ -112,16 +125,16 @@ def test_dataindices_todict(data_indices) -> None:
def test_internaldataindices_todict(data_indices) -> None:
expected_output = {
"input": {
"full": torch.Tensor([0, 1, 4, 5, 6]).to(torch.int),
"forcing": torch.Tensor([0]).to(torch.int),
"full": torch.Tensor([0, 1, 4, 5, 6, 7, 8]).to(torch.int),
"forcing": torch.Tensor([0, 5, 6]).to(torch.int),
"diagnostic": torch.Tensor([2, 3]).to(torch.int),
"prognostic": torch.Tensor([1, 4, 5, 6]).to(torch.int),
"prognostic": torch.Tensor([1, 4, 7, 8]).to(torch.int),
},
"output": {
"full": torch.Tensor([1, 2, 3, 4, 5, 6]).to(torch.int),
"forcing": torch.Tensor([0]).to(torch.int),
"full": torch.Tensor([1, 2, 3, 4, 7, 8]).to(torch.int),
"forcing": torch.Tensor([0, 5, 6]).to(torch.int),
"diagnostic": torch.Tensor([2, 3]).to(torch.int),
"prognostic": torch.Tensor([1, 4, 5, 6]).to(torch.int),
"prognostic": torch.Tensor([1, 4, 7, 8]).to(torch.int),
},
}

Expand All @@ -134,10 +147,10 @@ def test_internaldataindices_todict(data_indices) -> None:
def test_modelindices_todict(data_indices) -> None:
expected_output = {
"input": {
"full": torch.Tensor([0, 1, 2, 3]).to(torch.int),
"forcing": torch.Tensor([0]).to(torch.int),
"full": torch.Tensor([0, 1, 2, 3, 4]).to(torch.int),
"forcing": torch.Tensor([0, 2]).to(torch.int),
"diagnostic": torch.Tensor([]).to(torch.int),
"prognostic": torch.Tensor([1, 2, 3]).to(torch.int),
"prognostic": torch.Tensor([1, 3, 4]).to(torch.int),
},
"output": {
"full": torch.Tensor([0, 1, 2, 3, 4]).to(torch.int),
Expand All @@ -156,10 +169,10 @@ def test_modelindices_todict(data_indices) -> None:
def test_internalmodelindices_todict(data_indices) -> None:
expected_output = {
"input": {
"full": torch.Tensor([0, 1, 2, 3, 4]).to(torch.int),
"forcing": torch.Tensor([0]).to(torch.int),
"full": torch.Tensor([0, 1, 2, 3, 4, 5, 6]).to(torch.int),
"forcing": torch.Tensor([0, 3, 4]).to(torch.int),
"diagnostic": torch.Tensor([]).to(torch.int),
"prognostic": torch.Tensor([1, 2, 3, 4]).to(torch.int),
"prognostic": torch.Tensor([1, 2, 5, 6]).to(torch.int),
},
"output": {
"full": torch.Tensor([0, 1, 2, 3, 4, 5]).to(torch.int),
Expand Down
12 changes: 6 additions & 6 deletions tests/preprocessing/test_preprocessor_remapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,19 @@ def input_remapper():
},
)
statistics = {}
name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4, "d": 5}
name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "d": 4, "other": 5}
data_indices = IndexCollection(config=config, name_to_index=name_to_index)
return Remapper(config=config.data.remapper, data_indices=data_indices, statistics=statistics)


def test_remap_not_inplace(input_remapper) -> None:
x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 150.0], [6.0, 7.0, 8.0, 9.0, 10.0, 201.0]])
x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]])
input_remapper(x, in_place=False)
assert torch.allclose(x, torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 150.0], [6.0, 7.0, 8.0, 9.0, 10.0, 201.0]]))
assert torch.allclose(x, torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]]))


def test_remap(input_remapper) -> None:
x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 150.0], [6.0, 7.0, 8.0, 9.0, 10.0, 201.0]])
x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]])
expected_output = torch.Tensor(
[[1.0, 2.0, 3.0, 4.0, 5.0, -0.8660254, 0.5], [6.0, 7.0, 8.0, 9.0, 10.0, -0.93358043, -0.35836795]]
)
Expand All @@ -58,12 +58,12 @@ def test_remap(input_remapper) -> None:

def test_inverse_transform(input_remapper) -> None:
x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0, -0.8660254, 0.5], [6.0, 7.0, 8.0, 9.0, 10.0, -0.93358043, -0.35836795]])
expected_output = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 150.0], [6.0, 7.0, 8.0, 9.0, 10.0, 201.0]])
expected_output = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]])
assert torch.allclose(input_remapper.inverse_transform(x), expected_output)


def test_remap_inverse_transform(input_remapper) -> None:
x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 150.0], [6.0, 7.0, 8.0, 9.0, 10.0, 201.0]])
x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]])
assert torch.allclose(
input_remapper.inverse_transform(input_remapper.transform(x, in_place=False), in_place=False), x
)

0 comments on commit 33dccb9

Please sign in to comment.