diff --git a/lm_eval/evaluator.py b/lm_eval/evaluator.py
index 95e018a508..b70926b6ae 100644
--- a/lm_eval/evaluator.py
+++ b/lm_eval/evaluator.py
@@ -503,9 +503,14 @@ def evaluate(
         # aggregate results ; run bootstrap CIs
         for task_output in eval_tasks:
             task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters)
-        results, samples, configs, versions, num_fewshot = consolidate_results(
-            eval_tasks
-        )
+        (
+            results,
+            samples,
+            configs,
+            versions,
+            num_fewshot,
+            higher_is_better,
+        ) = consolidate_results(eval_tasks)
 
         ### Calculate group metrics ###
         if bool(results):
@@ -516,6 +521,27 @@ def evaluate(
                     # or `task_name: []`.
                     # we only want to operate on groups here.
                     continue
+
+                # collect all higher_is_better values for metrics
+                # in the group's subtasks.
+                # TODO: clean this up ; unify with the below metric_list loop?
+                _higher_is_better = {}
+                for task in task_list:
+                    for m, h in higher_is_better[task].items():
+                        if m not in _higher_is_better.keys():
+                            _higher_is_better[m] = h
+                        if (
+                            m in _higher_is_better
+                            and _higher_is_better[m] is not None
+                            and _higher_is_better[m] != h
+                        ):
+                            eval_logger.warning(
+                                f"Higher_is_better values for metric {m} in group {group} are not consistent. Defaulting to None."
+                            )
+                            _higher_is_better[m] = None
+                higher_is_better[group] = _higher_is_better
+
+                # collect all metric keys used by a subtask in the group.
                 metric_list = list(
                     {
                         key
@@ -591,6 +617,7 @@ def evaluate(
             "configs": dict(sorted(configs.items())),
             "versions": dict(sorted(versions.items())),
             "n-shot": dict(sorted(num_fewshot.items())),
+            "higher_is_better": dict(sorted(higher_is_better.items())),
             "n-samples": {
                 task_output.task_name: {
                     "original": len(task_output.task.eval_docs),
diff --git a/lm_eval/evaluator_utils.py b/lm_eval/evaluator_utils.py
index 4eb3d94ff5..82197b5d57 100644
--- a/lm_eval/evaluator_utils.py
+++ b/lm_eval/evaluator_utils.py
@@ -253,6 +253,9 @@ def consolidate_results(
     configs = collections.defaultdict(dict)
     # Tracks each task's version.
     versions = collections.defaultdict(dict)
+    # Track `higher_is_better` for each metric
+    higher_is_better = collections.defaultdict(dict)
+
     for task_output in eval_tasks:
         if "task_alias" in (task_config := task_output.task_config):
             results[task_output.task_name]["alias"] = task_config["task_alias"]
@@ -263,6 +266,7 @@
         configs[task_output.task_name] = task_output.task_config
         versions[task_output.task_name] = task_output.version
         samples[task_output.task_name] = task_output.logged_samples
+        higher_is_better[task_output.task_name] = task_output.task.higher_is_better()
         for (metric, filter_key), items in task_output.sample_metrics.items():
             metric_key = f"{metric},{filter_key}"
             results[task_output.task_name][metric_key] = task_output.agg_metrics[
@@ -272,7 +276,7 @@
             results[task_output.task_name][
                 f"{metric}_stderr,{filter_key}"
             ] = task_output.agg_metrics[f"{metric}_stderr,{filter_key}"]
-    return results, samples, configs, versions, num_fewshot
+    return results, samples, configs, versions, num_fewshot, higher_is_better
 
 
 @positional_deprecated
diff --git a/lm_eval/utils.py b/lm_eval/utils.py
index 2e71e38c99..54de16dd7a 100644
--- a/lm_eval/utils.py
+++ b/lm_eval/utils.py
@@ -26,6 +26,11 @@
 
 SPACING = " " * 47
 
+HIGHER_IS_BETTER_SYMBOLS = {
+    True: "↑",
+    False: "↓",
+}
+
 
 def hash_string(string: str) -> str:
     return hashlib.sha256(string.encode("utf-8")).hexdigest()
@@ -257,6 +262,7 @@ def make_table(result_dict, column: str = "results", sort_results: bool = True):
         "Filter",
         "n-shot",
         "Metric",
+        "",
         "Value",
         "",
         "Stderr",
@@ -277,6 +283,7 @@ def make_table(result_dict, column: str = "results", sort_results: bool = True):
         dic = result_dict[column][k]
         version = result_dict["versions"].get(k, "N/A")
         n = str(result_dict["n-shot"][k])
+        higher_is_better = result_dict.get("higher_is_better", {}).get(k, {})
 
         if "alias" in dic:
             k = dic.pop("alias")
@@ -286,13 +293,15 @@ def make_table(result_dict, column: str = "results", sort_results: bool = True):
             if m.endswith("_stderr"):
                 continue
 
+            hib = HIGHER_IS_BETTER_SYMBOLS.get(higher_is_better.get(m), "")
+
             if m + "_stderr" + "," + f in dic:
                 se = dic[m + "_stderr" + "," + f]
                 if se != "N/A":
                     se = "%.4f" % se
-                values.append([k, version, f, n, m, "%.4f" % v, "±", se])
+                values.append([k, version, f, n, m, hib, "%.4f" % v, "±", se])
             else:
-                values.append([k, version, f, n, m, "%.4f" % v, "", ""])
+                values.append([k, version, f, n, m, hib, "%.4f" % v, "", ""])
             k = ""
             version = ""
     md_writer.value_matrix = values
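
For reviewers, here is a minimal standalone sketch (not part of the patch; task and metric names are made up) of how the pieces above fit together: per-subtask `higher_is_better` dicts are merged into a group-level entry, metrics whose subtasks disagree fall back to `None`, and the table code maps the merged value to an ↑/↓ marker (or nothing) via `HIGHER_IS_BETTER_SYMBOLS`.

```python
# Illustrative sketch only, with assumed inputs: per-subtask dicts as returned
# by `task.higher_is_better()` and collected in `consolidate_results`.
HIGHER_IS_BETTER_SYMBOLS = {
    True: "↑",
    False: "↓",
}

higher_is_better = {
    "subtask_a": {"acc": True, "perplexity": False},
    "subtask_b": {"acc": True, "perplexity": True},  # disagrees with subtask_a
}

# Group-level merge, mirroring the evaluator change: the first value seen wins,
# and any later disagreement downgrades the metric to None with a warning.
merged = {}
for task, metrics in higher_is_better.items():
    for m, h in metrics.items():
        if m not in merged:
            merged[m] = h
        elif merged[m] is not None and merged[m] != h:
            print(f"Higher_is_better values for metric {m} are not consistent. Defaulting to None.")
            merged[m] = None

# Symbol lookup as done per table row: None (or an unknown metric) renders as "".
for m, h in merged.items():
    print(m, HIGHER_IS_BETTER_SYMBOLS.get(h, ""))
# acc ↑
# perplexity        <- no symbol, because the subtasks disagreed
```

In the rendered results table this shows up as the new unlabeled column between "Metric" and "Value", holding ↑, ↓, or nothing.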