Skip to content

Commit

Permalink
Add unit tests for repo save/load
Browse files Browse the repository at this point in the history
  • Loading branch information
Innixma committed Oct 17, 2024
1 parent 697b477 commit 8b8e06c
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 7 deletions.
5 changes: 3 additions & 2 deletions tabrepo/contexts/context_artificial.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def load_context_artificial(
problem_type: str = "regression",
seed=0,
include_hyperparameters: bool = False,
dtype=np.float32,
**kwargs,
):
# TODO write specification of dataframes schema, this code produces a minimal example that enables
Expand Down Expand Up @@ -88,11 +89,11 @@ def load_context_artificial(
dataset_name: {
fold: {
"pred_proba_dict_val": {
m: rng.random((123, n_classes)) if n_classes > 2 else rng.random(123)
m: rng.random((123, n_classes), dtype=dtype) if n_classes > 2 else rng.random(123, dtype=dtype)
for m in models
},
"pred_proba_dict_test": {
m: rng.random((13, n_classes)) if n_classes > 2 else rng.random(13)
m: rng.random((13, n_classes), dtype=dtype) if n_classes > 2 else rng.random(13, dtype=dtype)
for m in models
}
}
Expand Down
41 changes: 36 additions & 5 deletions tst/test_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
def verify_equivalent_repository(
repo1: EvaluationRepository | EvaluationRepositoryCollection,
repo2: EvaluationRepository | EvaluationRepositoryCollection,
exact: bool = True,
verify_ensemble: bool = False,
verify_configs_hyperparameters: bool = True,
backend: str = "native",
):
assert repo1.folds == repo2.folds
Expand All @@ -27,15 +29,25 @@ def verify_equivalent_repository(
repo2_test = repo2.predict_test(dataset=dataset, config=c, fold=f)
repo1_val = repo1.predict_val(dataset=dataset, config=c, fold=f)
repo2_val = repo2.predict_val(dataset=dataset, config=c, fold=f)
assert np.array_equal(repo1_test, repo2_test)
assert np.array_equal(repo1_val, repo2_val)
assert np.array_equal(repo1.labels_test(dataset=dataset, fold=f), repo2.labels_test(dataset=dataset, fold=f))
assert np.array_equal(repo1.labels_val(dataset=dataset, fold=f), repo2.labels_val(dataset=dataset, fold=f))
if exact:
assert np.array_equal(repo1_test, repo2_test)
assert np.array_equal(repo1_val, repo2_val)
else:
assert np.isclose(repo1_test, repo2_test).all()
assert np.isclose(repo1_val, repo2_val).all()
if exact:
assert np.array_equal(repo1.labels_test(dataset=dataset, fold=f), repo2.labels_test(dataset=dataset, fold=f))
assert np.array_equal(repo1.labels_val(dataset=dataset, fold=f), repo2.labels_val(dataset=dataset, fold=f))
else:
assert np.isclose(repo1.labels_test(dataset=dataset, fold=f), repo2.labels_test(dataset=dataset, fold=f)).all()
assert np.isclose(repo1.labels_val(dataset=dataset, fold=f), repo2.labels_val(dataset=dataset, fold=f)).all()
if verify_ensemble:
df_out_1, df_ensemble_weights_1 = repo1.evaluate_ensembles(datasets=repo1.datasets(), ensemble_size=10, backend=backend)
df_out_2, df_ensemble_weights_2 = repo2.evaluate_ensembles(datasets=repo2.datasets(), ensemble_size=10, backend=backend)
assert df_out_1.equals(df_out_2)
assert df_ensemble_weights_1.equals(df_ensemble_weights_2)
if verify_configs_hyperparameters:
assert repo1.configs_hyperparameters() == repo2.configs_hyperparameters()


def test_repository():
Expand Down Expand Up @@ -185,7 +197,10 @@ def test_repository_subset():
def test_repository_configs_hyperparameters():
repo1 = load_repo_artificial()
repo2 = load_repo_artificial(include_hyperparameters=True)
verify_equivalent_repository(repo1, repo2, verify_ensemble=True)
verify_equivalent_repository(repo1, repo2, verify_ensemble=True, verify_configs_hyperparameters=False)

with pytest.raises(AssertionError):
verify_equivalent_repository(repo1, repo2, verify_configs_hyperparameters=True)

configs = ['NeuralNetFastAI_r1', 'NeuralNetFastAI_r2']

Expand Down Expand Up @@ -241,6 +256,22 @@ def test_repository_configs_hyperparameters():
]}


def test_repository_save_load():
    """Check that a repository survives a ``to_dir`` / ``from_dir`` round trip.

    Two cases are covered:

    1. The default artificial repository, where the reloaded repository must
       match the original exactly (``exact=True``).
    2. An artificial repository built with ``dtype=np.float64``, where only
       approximate equality is required (``exact=False``).

    Both repositories are written beneath a temporary directory, so the test
    leaves nothing behind in the working directory and can never pick up stale
    artifacts from a previous run.
    """
    # Local imports keep this test self-contained without touching the
    # module-level import block.
    import os
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        # NOTE(review): verification stays inside the ``with`` block in case the
        # loaded repository reads its data lazily from disk -- confirm against
        # ``EvaluationRepository.from_dir``.
        repo = load_repo_artificial(include_hyperparameters=True)
        save_path = os.path.join(tmp_dir, "tmp_repo")
        repo.to_dir(path=save_path)
        repo_loaded = EvaluationRepository.from_dir(path=save_path)
        verify_equivalent_repository(repo1=repo, repo2=repo_loaded, verify_ensemble=True, exact=True)

        repo_float64 = load_repo_artificial(include_hyperparameters=True, dtype=np.float64)
        save_path_float64 = os.path.join(tmp_dir, "tmp_repo_from_float64")
        repo_float64.to_dir(path=save_path_float64)
        repo_loaded_float64 = EvaluationRepository.from_dir(path=save_path_float64)
        # exact=False because the loaded version is float32 and the original is float64
        verify_equivalent_repository(
            repo1=repo_float64,
            repo2=repo_loaded_float64,
            verify_ensemble=True,
            exact=False,
        )


def _assert_predict_multi_binary_as_multiclass(repo, fun: Callable, dataset, configs, n_rows, n_classes):
problem_type = repo.dataset_info(dataset=dataset)["problem_type"]
predict_multi = fun(dataset=dataset, fold=2, configs=configs)
Expand Down

0 comments on commit 8b8e06c

Please sign in to comment.