Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add equality checks to the UQ result readers. #417

Merged
merged 1 commit into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/alfasim_sdk/result_reader/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class TimeSetInfoItem(NamedTuple):
uuid: str


TimeSetInfo = Dict[int, TimeSetInfoItem]
TimeSetInfo = Dict[TimeStepIndex, TimeSetInfoItem]


_PROFILE_ID_ATTR = "profile_id"
Expand Down Expand Up @@ -1957,7 +1957,7 @@ def read_uncertainty_propagation_analyses_meta_data(
)


@attr.s(frozen=True)
@attr.s(frozen=True, eq=False)
class UPResult:
"""
Holder for each uncertainty propagation result.
Expand All @@ -1967,6 +1967,16 @@ class UPResult:
std_result: np.ndarray = attr.ib(default=attr.Factory(lambda: np.array([])))
mean_result: np.ndarray = attr.ib(default=attr.Factory(lambda: np.array([])))

def __eq__(self, other: Any) -> bool:
if not isinstance(other, UPResult):
return False

return (
np.array_equal(self.realization_output, other.realization_output)
and np.array_equal(self.std_result, other.std_result)
and np.array_equal(self.mean_result, other.mean_result)
)


def read_uncertainty_propagation_results(
metadata: UncertaintyPropagationAnalysesMetaData,
Expand Down
44 changes: 41 additions & 3 deletions src/alfasim_sdk/result_reader/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def validator(inst: Any, attribute: attr.Attribute, value: Any) -> None:
return validator


@define(frozen=True)
@define(frozen=True, eq=False)
class GlobalSensitivityAnalysisResults:
timeset: np.ndarray = attr.field(validator=attr.validators.min_len(1))
coefficients: dict[GSAOutputKey, np.ndarray] = attr.field(
Expand Down Expand Up @@ -393,6 +393,22 @@ def get_sensitivity_curve(
domain = Array(self.timeset, "s")
return Curve(image=image, domain=domain)

def __eq__(self, other: Any) -> bool:
if not isinstance(other, GlobalSensitivityAnalysisResults):
return False

return (
np.array_equal(self.timeset, other.timeset)
and _all_dict_close(self.coefficients, other.coefficients)
and self.metadata == other.metadata
)


def _all_dict_close(a: dict[Any, np.ndarray], b: dict[Any, np.ndarray]) -> bool:
if a.keys() != b.keys():
return False
return all(np.array_equal(a[key], b[key]) for key in a)


@define(frozen=True)
class _BaseHistoryMatchingResults:
Expand Down Expand Up @@ -423,7 +439,7 @@ def from_directory(cls, result_dir: Path) -> Self | None:
)


@define(frozen=True)
@define(frozen=True, eq=False)
class HistoryMatchingProbabilisticResults(_BaseHistoryMatchingResults):
probabilistic_distributions: dict[HMOutputKey, np.ndarray] = attr.field(
validator=_non_empty_dict_validator(values_type=np.ndarray)
Expand All @@ -443,6 +459,18 @@ def from_directory(cls, result_dir: Path) -> Self | None:
metadata=metadata,
)

def __eq__(self, other: Any) -> bool:
if not isinstance(other, HistoryMatchingProbabilisticResults):
return False

nicoddemus marked this conversation as resolved.
Show resolved Hide resolved
return (
_all_dict_close(
self.probabilistic_distributions, other.probabilistic_distributions
)
and self.historic_data_curves == other.historic_data_curves
and self.metadata == other.metadata
)


def _read_curves_data(
metadata: HistoryMatchingMetadata,
Expand All @@ -460,7 +488,7 @@ def _read_curves_data(
return result


@define(frozen=True)
@define(frozen=True, eq=False)
class UncertaintyPropagationResults:
timeset: np.ndarray = attr.field(validator=attr.validators.min_len(1))
results: dict[UPOutputKey, UPResult] = attr.field(
Expand All @@ -481,3 +509,13 @@ def from_directory(cls, result_dir: Path) -> Self | None:
results=read_uncertainty_propagation_results(metadata),
metadata=metadata,
)

def __eq__(self, other: Any) -> bool:
if not isinstance(other, UncertaintyPropagationResults):
return False

return (
np.array_equal(self.timeset, other.timeset)
and self.results == other.results
and self.metadata == other.metadata
)
28 changes: 28 additions & 0 deletions tests/results/test_result_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
from pathlib import Path

import attr
import numpy
import numpy as np
import pytest
Expand Down Expand Up @@ -188,6 +189,11 @@ def test_global_sensitivity_analysis_results_reader(
qoi_data_index=0,
)

# Test equality check.
assert results == attr.evolve(results)
assert results != object()
assert results != attr.evolve(results, timeset=np.array([0.1, 0.2]))

# Ensure the reader can handle a nonexistent result file.
results = GlobalSensitivityAnalysisResults.from_directory(Path("foo"))
assert results is None
Expand All @@ -204,6 +210,13 @@ def test_deterministic_reader(self, hm_deterministic_results_dir: Path) -> None:
}
self._validate_meta_and_historic_curves(results)

# Test equality check.
assert results == attr.evolve(results)
assert results != object()
assert results != attr.evolve(
results, deterministic_values={HMOutputKey("parametric_var_1"): 0.1}
)

def test_probabilistic_reader(self, hm_probabilistic_results_dir: Path) -> None:
results = HistoryMatchingProbabilisticResults.from_directory(
hm_probabilistic_results_dir
Expand All @@ -217,6 +230,16 @@ def test_probabilistic_reader(self, hm_probabilistic_results_dir: Path) -> None:
)
self._validate_meta_and_historic_curves(results)

# Test equality check.
assert results == attr.evolve(results)
assert results != object()
assert results != attr.evolve(
results,
probabilistic_distributions={
HMOutputKey("parametric_var_1"): np.array([0.1, 0.3])
},
)

def test_wrong_result_file(
self, hm_probabilistic_results_dir: Path, hm_deterministic_results_dir: Path
) -> None:
Expand Down Expand Up @@ -331,6 +354,11 @@ def test_uncertainty_propagation_results_reader(up_results_dir: Path) -> None:
sample_indexes=[[0, 0], [0, 1], [0, 2], [0, 3], [0, 4]],
)

# Test equality check.
assert reader == attr.evolve(reader)
assert reader != object()
assert reader != attr.evolve(reader, timeset=np.array([0.1, 0.2]))

# Ensure the reader can handle a nonexistent result file.
reader = UncertaintyPropagationResults.from_directory(Path("foo"))
assert reader is None
Loading