higher_is_better tickers in output table #1893

Merged · 5 commits · May 30, 2024
lm_eval/evaluator.py (30 additions, 3 deletions)

@@ -503,9 +503,14 @@ def evaluate(
         # aggregate results ; run bootstrap CIs
         for task_output in eval_tasks:
             task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters)
-        results, samples, configs, versions, num_fewshot = consolidate_results(
-            eval_tasks
-        )
+        (
+            results,
+            samples,
+            configs,
+            versions,
+            num_fewshot,
+            higher_is_better,
+        ) = consolidate_results(eval_tasks)

         ### Calculate group metrics ###
         if bool(results):
@@ -516,6 +521,27 @@
                     # or `task_name: []`.
                     # we only want to operate on groups here.
                     continue
+
+                # collect all higher_is_better values for metrics
+                # in the group's subtasks.
+                # TODO: clean this up ; unify with the below metric_list loop?
+                _higher_is_better = {}
+                for task in task_list:
+                    for m, h in higher_is_better[task].items():
+                        if m not in _higher_is_better.keys():
+                            _higher_is_better[m] = h
+                        if (
+                            m in _higher_is_better
+                            and _higher_is_better[m] is not None
+                            and _higher_is_better[m] != h
+                        ):
+                            eval_logger.warning(
+                                f"Higher_is_better values for metric {m} in group {group} are not consistent. Defaulting to None."
+                            )
+                            _higher_is_better[m] = None
+                higher_is_better[group] = _higher_is_better
+
+                # collect all metric keys used by a subtask in the group.
                 metric_list = list(
                     {
                         key
@@ -591,6 +617,7 @@
             "configs": dict(sorted(configs.items())),
             "versions": dict(sorted(versions.items())),
             "n-shot": dict(sorted(num_fewshot.items())),
+            "higher_is_better": dict(sorted(higher_is_better.items())),
             "n-samples": {
                 task_output.task_name: {
                     "original": len(task_output.task.eval_docs),
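The group-level loop above merges each subtask's per-metric direction and falls back to None when subtasks disagree. A minimal, standalone sketch of the same merging rule (the function name `reconcile_directions` and the example task names are hypothetical, not part of the PR):

import logging
from typing import Dict, Optional

eval_logger = logging.getLogger(__name__)


def reconcile_directions(
    subtask_directions: Dict[str, Dict[str, Optional[bool]]],
    group: str,
) -> Dict[str, Optional[bool]]:
    """Merge per-subtask {metric: higher_is_better} dicts into one group-level dict.

    If two subtasks disagree on a metric's direction, the group-level entry is set
    to None and a warning is logged, mirroring the loop added to evaluate() above.
    """
    merged: Dict[str, Optional[bool]] = {}
    for task, directions in subtask_directions.items():
        for metric, hib in directions.items():
            if metric not in merged:
                merged[metric] = hib
            elif merged[metric] is not None and merged[metric] != hib:
                eval_logger.warning(
                    f"higher_is_better for metric {metric} in group {group} is inconsistent; defaulting to None."
                )
                merged[metric] = None
    return merged


# Hypothetical example: both subtasks agree on acc but disagree on bits_per_byte.
print(
    reconcile_directions(
        {
            "task_a": {"acc": True, "bits_per_byte": False},
            "task_b": {"acc": True, "bits_per_byte": True},
        },
        group="demo_group",
    )
)
# -> {'acc': True, 'bits_per_byte': None}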
lm_eval/evaluator_utils.py (5 additions, 1 deletion)

@@ -253,6 +253,9 @@ def consolidate_results(
     configs = collections.defaultdict(dict)
     # Tracks each task's version.
     versions = collections.defaultdict(dict)
+    # Track `higher_is_better` for each metric
+    higher_is_better = collections.defaultdict(dict)
+
     for task_output in eval_tasks:
         if "task_alias" in (task_config := task_output.task_config):
             results[task_output.task_name]["alias"] = task_config["task_alias"]
@@ -263,6 +266,7 @@
         configs[task_output.task_name] = task_output.task_config
         versions[task_output.task_name] = task_output.version
         samples[task_output.task_name] = task_output.logged_samples
+        higher_is_better[task_output.task_name] = task_output.task.higher_is_better()
         for (metric, filter_key), items in task_output.sample_metrics.items():
             metric_key = f"{metric},{filter_key}"
             results[task_output.task_name][metric_key] = task_output.agg_metrics[
@@ -272,7 +276,7 @@
             results[task_output.task_name][
                 f"{metric}_stderr,{filter_key}"
             ] = task_output.agg_metrics[f"{metric}_stderr,{filter_key}"]
-    return results, samples, configs, versions, num_fewshot
+    return results, samples, configs, versions, num_fewshot, higher_is_better


 @positional_deprecated
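For context, consolidate_results now also returns a per-task mapping of metric name to direction as its sixth value, and callers unpack it alongside the existing five. A rough sketch of the shapes involved (the task name, metric values, and `direction_label` helper below are illustrative assumptions, not code from the PR):

from typing import Dict, Optional


def direction_label(hib: Optional[bool]) -> str:
    """Translate a higher_is_better flag into a human-readable label."""
    if hib is None:
        return "mixed/unknown"
    return "higher is better" if hib else "lower is better"


# Hypothetical slice of the new per-task mapping that consolidate_results returns
# (and that evaluate() stores under results_dict["higher_is_better"]):
higher_is_better: Dict[str, Dict[str, Optional[bool]]] = {
    "demo_task": {"acc": True, "perplexity": False},
}

for task, metrics in higher_is_better.items():
    for metric, hib in metrics.items():
        print(f"{task}/{metric}: {direction_label(hib)}")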
lm_eval/utils.py (11 additions, 2 deletions)

@@ -26,6 +26,11 @@

 SPACING = " " * 47

+HIGHER_IS_BETTER_SYMBOLS = {
+    True: "↑",
+    False: "↓",
+}
+

 def hash_string(string: str) -> str:
     return hashlib.sha256(string.encode("utf-8")).hexdigest()
@@ -257,6 +262,7 @@ def make_table(result_dict, column: str = "results", sort_results: bool = True):
         "Filter",
         "n-shot",
         "Metric",
+        "",
         "Value",
         "",
         "Stderr",
@@ -277,6 +283,7 @@
         dic = result_dict[column][k]
         version = result_dict["versions"].get(k, "N/A")
         n = str(result_dict["n-shot"][k])
+        higher_is_better = result_dict.get("higher_is_better", {}).get(k, {})

         if "alias" in dic:
             k = dic.pop("alias")
@@ -286,13 +293,15 @@
             if m.endswith("_stderr"):
                 continue

+            hib = HIGHER_IS_BETTER_SYMBOLS.get(higher_is_better.get(m), "")
+
             if m + "_stderr" + "," + f in dic:
                 se = dic[m + "_stderr" + "," + f]
                 if se != "N/A":
                     se = "%.4f" % se
-                values.append([k, version, f, n, m, "%.4f" % v, "±", se])
+                values.append([k, version, f, n, m, hib, "%.4f" % v, "±", se])
             else:
-                values.append([k, version, f, n, m, "%.4f" % v, "", ""])
+                values.append([k, version, f, n, m, hib, "%.4f" % v, "", ""])
             k = ""
             version = ""
     md_writer.value_matrix = values
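To see the ticker end to end, something like the following should work against the patched make_table. This is a sketch, assuming the keys shown here ("results", "versions", "n-shot", "higher_is_better") are all that make_table consults; the task name and numbers are made up for illustration:

from lm_eval.utils import make_table

# Hypothetical result_dict containing only the keys the patched make_table reads.
result_dict = {
    "results": {
        "demo_task": {"acc,none": 0.7512, "acc_stderr,none": 0.0043},
    },
    "versions": {"demo_task": 1.0},
    "n-shot": {"demo_task": 0},
    # New in this PR: per-task {metric: higher_is_better} flags.
    "higher_is_better": {"demo_task": {"acc": True}},
}

# The "acc" row should now carry an extra "↑" column between Metric and Value.
print(make_table(result_dict))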