Commit
Update repo.evaluate_ensemble to return DataFrame
Innixma committed Oct 11, 2024
1 parent 6024700 commit e38a3e7
Showing 2 changed files with 23 additions and 19 deletions.
28 changes: 16 additions & 12 deletions tabrepo/repository/ensemble_mixin.py
@@ -26,7 +26,7 @@ def evaluate_ensemble(
         rank: bool = True,
         folds: list[int] | None = None,
         backend: str = "ray",
-    ) -> Tuple[pd.Series, pd.DataFrame]:
+    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
         """
         :param datasets: list of datasets to compute errors on.
         :param configs: list of config to consider for ensembling. Uses all configs if None.
@@ -58,22 +58,29 @@ def evaluate_ensemble(
             backend=backend,
         )
 
+        dataset_folds = [(self.task_to_dataset(task=task), self.task_to_fold(task=task)) for task in tasks]
         dict_errors, dict_ensemble_weights = scorer.compute_errors(configs=configs)
+        metric_error_list = [dict_errors[task] for task in tasks]
+        datasets_info = self.datasets_info(datasets=datasets)
+        dataset_metric_list = [datasets_info.loc[d]["metric"] for d, f in dataset_folds]
+        problem_type_list = [datasets_info.loc[d]["problem_type"] for d, f in dataset_folds]
+
+        output_dict = {
+            "metric_error": metric_error_list,
+            "metric": dataset_metric_list,
+            "problem_type": problem_type_list,
+        }
 
         if rank:
-            dict_scores = scorer.compute_ranks(errors=dict_errors)
-            out = dict_scores
-        else:
-            out = dict_errors
+            dict_ranks = scorer.compute_ranks(errors=dict_errors)
+            rank_list = [dict_ranks[task] for task in tasks]
+            output_dict["rank"] = rank_list
 
-        dataset_folds = [(self.task_to_dataset(task=task), self.task_to_fold(task=task)) for task in tasks]
         ensemble_weights = [dict_ensemble_weights[task] for task in tasks]
-        out_list = [out[task] for task in tasks]
 
         multiindex = pd.MultiIndex.from_tuples(dataset_folds, names=["dataset", "fold"])
 
-        df_name = "rank" if rank else "metric_error"
-        df_out = pd.Series(data=out_list, index=multiindex, name=df_name)
+        df_out = pd.DataFrame(data=output_dict, index=multiindex)
         df_ensemble_weights = pd.DataFrame(data=ensemble_weights, index=multiindex, columns=configs)
 
         return df_out, df_ensemble_weights
@@ -161,11 +168,8 @@ def evaluate_ensemble_with_time(
             time_infer_s = sum(latencies.values())
 
             task_time_map[(dataset, fold)] = {"time_train_s": time_train_s, "time_infer_s": time_infer_s}
-        df_out = df_out.to_frame()
         df_task_time = pd.DataFrame(task_time_map).T
         df_out[["time_train_s", "time_infer_s"]] = df_task_time
-        df_datasets_info = self.datasets_info(datasets=[dataset])
-        df_out = df_out.join(df_datasets_info, how="inner")
         df_datasets_to_tids = self.datasets_to_tids(datasets=[dataset]).to_frame()
         df_datasets_to_tids.index.name = "dataset"
         df_out = df_out.join(df_datasets_to_tids, how="inner")
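With this change, evaluate_ensemble returns a pandas DataFrame indexed by a (dataset, fold) MultiIndex with metric_error, metric, problem_type and, when rank=True (the default), rank columns, instead of the previous pd.Series of errors or ranks. A minimal sketch of how a caller might consume the new output, assuming repo is an already-loaded EvaluationRepository; the configs() lookup is an assumption, only the evaluate_ensemble call and the column names come from this commit:

# Sketch only: `repo` is assumed to be a loaded EvaluationRepository.
import pandas as pd

configs = repo.configs()[:2]  # assumed helper: any two config names in the repository
df_out, df_ensemble_weights = repo.evaluate_ensemble(
    datasets=["abalone"],  # dataset name taken from the tests below
    configs=configs,
    ensemble_size=5,
    backend="native",
)

# df_out was previously a pd.Series; it is now a DataFrame with one row per task.
assert isinstance(df_out, pd.DataFrame)
assert list(df_out.columns) == ["metric_error", "metric", "problem_type", "rank"]

# Rows are indexed by a (dataset, fold) MultiIndex.
metric_error = df_out.loc[("abalone", 0), "metric_error"]

# Ensemble weights: one row per task, one column per requested config.
weights = df_ensemble_weights.loc[("abalone", 0)]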
14 changes: 7 additions & 7 deletions tst/test_repository.py
@@ -61,7 +61,7 @@ def test_repository():
     assert repo.labels_test(dataset=dataset, fold=2).shape == (13,)
     assert repo.dataset_metadata(dataset=dataset) == {'dataset': dataset, 'task_type': 'TaskType.SUPERVISED_CLASSIFICATION'}
     result_errors, result_ensemble_weights = repo.evaluate_ensemble(datasets=[dataset], configs=[config, config], ensemble_size=5, backend="native")
-    assert result_errors.shape == (3,)
+    assert result_errors.shape == (3, 4)
     assert len(result_ensemble_weights) == 3
 
     dataset_info = repo.dataset_info(dataset=dataset)
@@ -76,12 +76,12 @@ def test_repository():
     result_errors_w_max_models, result_ensemble_weights_w_max_models = repo.evaluate_ensemble(
         datasets=[dataset], configs=[config, config], ensemble_size=5, backend="native", ensemble_kwargs={"max_models_per_type": 1}
     )
-    assert result_errors_w_max_models.shape == (3,)
+    assert result_errors_w_max_models.shape == (3, 4)
     assert len(result_ensemble_weights_w_max_models) == 3
     assert np.allclose(result_ensemble_weights_w_max_models.loc[(dataset, 0)], [1.0, 0.0])
 
     assert repo.evaluate_ensemble(datasets=[dataset], configs=[config, config],
-                                  ensemble_size=5, folds=[2], backend="native")[0].shape == (1,)
+                                  ensemble_size=5, folds=[2], backend="native")[0].shape == (1, 4)
 
     repo: EvaluationRepository = repo.subset(folds=[0, 2])
     assert repo.datasets() == ['abalone', 'ada']
@@ -93,9 +93,9 @@ def test_repository():
     assert repo.predict_test(dataset=dataset, config=config, fold=2).shape == (13, 25)
     assert repo.dataset_metadata(dataset=dataset) == {'dataset': dataset, 'task_type': 'TaskType.SUPERVISED_CLASSIFICATION'}
     # result_errors, result_ensemble_weights = repo.evaluate_ensemble(datasets=[dataset], configs=[config, config], ensemble_size=5, backend="native")[0],
-    assert repo.evaluate_ensemble(datasets=[dataset], configs=[config, config], ensemble_size=5, backend="native")[0].shape == (2,)
+    assert repo.evaluate_ensemble(datasets=[dataset], configs=[config, config], ensemble_size=5, backend="native")[0].shape == (2, 4)
     assert repo.evaluate_ensemble(datasets=[dataset], configs=[config, config],
-                                  ensemble_size=5, folds=[2], backend="native")[0].shape == (1,)
+                                  ensemble_size=5, folds=[2], backend="native")[0].shape == (1, 4)
 
     repo: EvaluationRepository = repo.subset(folds=[2], datasets=[dataset], configs=[config])
     assert repo.datasets() == ['abalone']
@@ -106,10 +106,10 @@ def test_repository():
     assert repo.predict_val(dataset=dataset, config=config, fold=2).shape == (123, 25)
     assert repo.predict_test(dataset=dataset, config=config, fold=2).shape == (13, 25)
     assert repo.dataset_metadata(dataset=dataset) == {'dataset': dataset, 'task_type': 'TaskType.SUPERVISED_CLASSIFICATION'}
-    assert repo.evaluate_ensemble(datasets=[dataset], configs=[config, config], ensemble_size=5, backend="native")[0].shape == (1,)
+    assert repo.evaluate_ensemble(datasets=[dataset], configs=[config, config], ensemble_size=5, backend="native")[0].shape == (1, 4)
 
     assert repo.evaluate_ensemble(datasets=[dataset], configs=[config, config],
-                                  ensemble_size=5, folds=[2], backend="native")[0].shape == (1,)
+                                  ensemble_size=5, folds=[2], backend="native")[0].shape == (1, 4)
 
 
 def test_repository_force_to_dense():
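The updated assertions reflect the new return type: each result now has shape (n_tasks, 4) rather than (n_tasks,), with one row per (dataset, fold) task and one column per output field. A small sanity-check sketch under the same test setup (repo, dataset, and config as in test_repository; this snippet is illustrative, not part of the commit):

# Hedged sketch mirroring the updated assertions above.
result_errors, result_ensemble_weights = repo.evaluate_ensemble(
    datasets=[dataset], configs=[config, config], ensemble_size=5, backend="native"
)

n_tasks = len(result_errors)                          # one row per (dataset, fold)
assert result_errors.shape == (n_tasks, 4)            # four output columns per task
assert result_ensemble_weights.shape == (n_tasks, 2)  # one weight column per passed config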
