Add HumanEval (#1992)
* add custom filter

* fix type casting of references

* add humaneval

* fix a bug in humaneval

* add greedy version of humaneval

* update tasks README

* test humaneval

* return multiple metrics

* nit

* add confirmation to run code tasks

* nit

* nit

---------

Co-authored-by: Hojin Lee <[email protected]>
Co-authored-by: Baber <[email protected]>
3 people authored Jan 15, 2025
1 parent bb098f1 commit 4c11206
Showing 11 changed files with 184 additions and 8 deletions.
6 changes: 6 additions & 0 deletions lm_eval/__main__.py
@@ -257,6 +257,11 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
)
parser.add_argument(
"--confirm_run_unsafe_code",
action="store_true",
help="Confirm that you understand the risks of running unsafe code for tasks that require it",
)
return parser


@@ -404,6 +409,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
numpy_random_seed=args.seed[1],
torch_random_seed=args.seed[2],
fewshot_random_seed=args.seed[3],
confirm_run_unsafe_code=args.confirm_run_unsafe_code,
**request_caching_args,
)

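The new flag is threaded through to `simple_evaluate` below; a minimal usage sketch (the model choice and `pretrained=` value are placeholders, not part of this commit):

```python
# Hypothetical sketch: opting in to unsafe-code tasks such as humaneval.
# CLI equivalent: lm_eval --model hf --model_args pretrained=<model> \
#                 --tasks humaneval --confirm_run_unsafe_code
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=gpt2",      # placeholder model
    tasks=["humaneval"],
    confirm_run_unsafe_code=True,      # omit this and evaluate() raises ValueError
)
```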
15 changes: 11 additions & 4 deletions lm_eval/api/task.py
@@ -75,6 +75,7 @@ class TaskConfig(dict):
doc_to_text: Optional[Union[Callable, str]] = None
doc_to_target: Optional[Union[Callable, str]] = None
doc_to_image: Union[Callable, str] = None
unsafe_code: bool = False
doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
process_results: Optional[Union[Callable, str]] = None
use_prompt: Optional[str] = None
@@ -732,6 +733,9 @@ def __init__(
# mark the task as requiring multimodality.
self.MULTIMODAL = True

if self.config.unsafe_code is not False:
self.UNSAFE_CODE = True

if self.config.dataset_path is not None:
self.DATASET_PATH = self.config.dataset_path

@@ -1503,9 +1507,9 @@ def process_results(self, doc, results):
# we expect multiple_targets to be a list.
elif self.multiple_target:
gold = list(gold)
elif (
type(gold) is not type(result)
and "bypass" not in self._metric_fn_list.keys()
# TODO: handle this better
elif type(gold) is not type(result) and not (
"bypass" in self._metric_fn_list.keys() or isinstance(result, list)
):
# cast gold to the same type as result
gold = type(result)(gold)
@@ -1561,7 +1565,10 @@ def process_results(self, doc, results):
result_score = self._metric_fn_list[metric]([gold, result])
if isinstance(result_score, dict):
# TODO: this handles the case where HF evaluate returns a dict.
result_score = result_score[metric]
# This allows for multiple metrics to be returned from the same function
for k, v in result_score.items():
result_dict[k] = v
return result_dict
result_dict[metric] = result_score
else:
raise ValueError(
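The early return above lets a single metric function report several named scores for one sample. An illustrative sketch (the function and metric names are invented for the example, not part of this commit):

```python
# Illustrative only: a metric function that returns a dict of named scores.
# With the change above, every key/value pair is copied into result_dict,
# so one function can yield e.g. both pass@1 and pass@8 for a sample.
def multi_pass_metric(items):
    gold, result = items                     # gold: reference, result: model output(s)
    return {"pass@1": 1.0, "pass@8": 1.0}    # placeholder scores
```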
21 changes: 19 additions & 2 deletions lm_eval/evaluator.py
@@ -74,6 +74,7 @@ def simple_evaluate(
numpy_random_seed: int = 1234,
torch_random_seed: int = 1234,
fewshot_random_seed: int = 1234,
confirm_run_unsafe_code: bool = False,
):
"""Instantiate and evaluate a model on a list of tasks.
@@ -313,6 +314,7 @@ def _adjust_config(task_dict):
apply_chat_template=apply_chat_template,
fewshot_as_multiturn=fewshot_as_multiturn,
verbosity=verbosity,
confirm_run_unsafe_code=confirm_run_unsafe_code,
)

if lm.rank == 0:
@@ -372,6 +374,7 @@ def evaluate(
apply_chat_template: Union[bool, str] = False,
fewshot_as_multiturn: bool = False,
verbosity: str = "INFO",
confirm_run_unsafe_code: bool = False,
):
"""Instantiate and evaluate a model on a list of tasks.
@@ -381,6 +384,10 @@
Dictionary of tasks. Tasks will be taken to have name type(task).config.task .
:param limit: int, optional
Limit the number of examples per task (only use this for testing)
:param cache_requests: bool, optional
Speed up evaluation by caching the building of dataset requests.
:param rewrite_requests_cache: bool, optional
        Rewrites the entire request cache if set to `True`.
:param bootstrap_iters:
Number of iterations for bootstrap statistics, used when calculating stderr. Set to 0 for skipping all stderr calculations.
:param write_out: bool
@@ -396,6 +403,10 @@
Defaults to False (no chat template applied).
:param fewshot_as_multiturn: bool
Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
:param verbosity: str
Verbosity level for logging
:param confirm_run_unsafe_code: bool
Whether to confirm running tasks marked as unsafe.
:return
Dictionary of results
"""
@@ -422,13 +433,19 @@
):
raise ValueError("log_samples must be True for 'bypass' metric-only tasks")

# validation check: are we running multimodal task <-> non-multimodal model class, or vice-versa.
# validation checks:
    # 1. are we running a multimodal task <-> non-multimodal model class, or vice-versa.
    # 2. are we running code that is marked as unsafe.
incompatible_tasks = []
for task_output in eval_tasks:
task: Task = task_output.task

if getattr(lm, "MULTIMODAL", False) != getattr(task, "MULTIMODAL", False):
incompatible_tasks.append(task_output.task_name)
elif getattr(task, "UNSAFE_CODE", False) and not confirm_run_unsafe_code:
raise ValueError(
f"Attempted to run task: {task_output.task_name} which is marked as unsafe. Set confirm_run_unsafe_code=True to run this task."
)
if len(incompatible_tasks) > 0:
if not getattr(lm, "MULTIMODAL", False):
raise ValueError(
@@ -438,7 +455,7 @@
raise ValueError(
f"Attempted to run tasks: {incompatible_tasks} which are text-only, but used a model type which only currently supports multimodal tasks."
)
# end multimodality validation check
# end validation check

# Cache the limit arg.
limit_arg = limit
8 changes: 7 additions & 1 deletion lm_eval/evaluator_utils.py
@@ -7,6 +7,7 @@
from lm_eval.api.group import ConfigurableGroup
from lm_eval.api.metrics import (
aggregate_subtask_metrics,
mean,
pooled_sample_stderr,
stderr_for_metric,
)
@@ -99,7 +100,12 @@ def from_taskdict(cls, task_name: str, task):

def calculate_aggregate_metric(self, bootstrap_iters=100000) -> None:
for (metric, filter_key), items in self.sample_metrics.items():
agg_fn = self.task.aggregation()[metric]
try:
agg_fn = self.task.aggregation()[metric]
except KeyError:
                # This is the case when process_results outputs an arbitrary metric name
                # TODO: handle this better and allow aggregate functions other than mean.
agg_fn = mean
metric_key = f"{metric},{filter_key}"
self.agg_metrics[metric_key] = agg_fn(items)
self.sample_len = len(items) # TODO: same sample size for each metric?
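Metric names produced that way (e.g. `pass@1` from the HumanEval task) have no entry in the task's aggregation table, so the `KeyError` fallback aggregates them with `mean`. A small illustration:

```python
from lm_eval.api.metrics import mean

# Per-sample pass@1 values emitted by a dict-returning metric function;
# with no registered aggregation for "pass@1", the fallback above applies.
per_sample_scores = [1.0, 0.0, 1.0, 1.0]
print(mean(per_sample_scores))  # 0.75
```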
2 changes: 1 addition & 1 deletion lm_eval/filters/__init__.py
@@ -4,7 +4,7 @@
from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.registry import get_filter

from . import extraction, selection, transformation
from . import custom, extraction, selection, transformation


def build_filter_ensemble(
17 changes: 17 additions & 0 deletions lm_eval/filters/custom.py
@@ -0,0 +1,17 @@
from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter


@register_filter("custom")
class CustomFilter(Filter):
"""
    Custom filter that applies a user-defined function to the model responses.
"""

def __init__(self, **kwargs) -> None:
self.filter_fn = kwargs.pop("filter_fn")

super().__init__(**kwargs)

def apply(self, resps, docs):
return self.filter_fn(resps, docs)
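In the HumanEval config further down, `filter_fn` is supplied from YAML via `!function utils.build_predictions`. A standalone sketch of the same mechanism (the lambda and toy data are illustrative):

```python
# Illustrative sketch: the "custom" filter delegates to a user-supplied
# callable taking (resps, docs) and returning the filtered responses.
from lm_eval.filters.custom import CustomFilter

prepend_prompt = CustomFilter(
    filter_fn=lambda resps, docs: [
        [doc["prompt"] + r for r in resp] for resp, doc in zip(resps, docs)
    ]
)

resps = [["    return a + b"]]             # model completions, grouped per doc
docs = [{"prompt": "def add(a, b):\n"}]    # matching dataset docs
print(prepend_prompt.apply(resps, docs))
# [['def add(a, b):\n    return a + b']]
```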
1 change: 1 addition & 0 deletions lm_eval/tasks/README.md
@@ -55,6 +55,7 @@
| [hellaswag](hellaswag/README.md) | Tasks to predict the ending of stories or scenarios, testing comprehension and creativity. | English |
| [hendrycks_ethics](hendrycks_ethics/README.md) | Tasks designed to evaluate the ethical reasoning capabilities of models. | English |
| [hendrycks_math](hendrycks_math/README.md) | Mathematical problem-solving tasks to test numerical reasoning and problem-solving. | English |
| [humaneval](humaneval/README.md) | Code generation task that measures functional correctness for synthesizing programs from docstrings. | Python |
| [ifeval](ifeval/README.md) | Interactive fiction evaluation tasks for narrative understanding and reasoning. | English |
| [inverse_scaling](inverse_scaling/README.md) | Multiple-choice tasks from the Inverse Scaling Prize, designed to find settings where larger language models perform worse. | English |
| [japanese_leaderboard](japanese_leaderboard/README.md) | Japanese language understanding tasks to benchmark model performance on various linguistic aspects. | Japanese |
46 changes: 46 additions & 0 deletions lm_eval/tasks/humaneval/README.md
@@ -0,0 +1,46 @@
# HumanEval

## Paper
Evaluating Large Language Models Trained on Code
https://arxiv.org/abs/2107.03374

We introduce Codex, a GPT language model fine-tuned on publicly available code from GitHub, and study its Python code-writing capabilities. A distinct production version of Codex powers GitHub Copilot. On HumanEval, a new evaluation set we release to measure functional correctness for synthesizing programs from docstrings, our model solves 28.8% of the problems, while GPT-3 solves 0% and GPT-J solves 11.4%. Furthermore, we find that repeated sampling from the model is a surprisingly effective strategy for producing working solutions to difficult prompts. Using this method, we solve 70.2% of our problems with 100 samples per problem. Careful investigation of our model reveals its limitations, including difficulty with docstrings describing long chains of operations and with binding operations to variables. Finally, we discuss the potential broader impacts of deploying powerful code generation technologies, covering safety, security, and economics.

Homepage: https://github.com/openai/human-eval


## Citation
```
@article{chen2021codex,
title={Evaluating Large Language Models Trained on Code},
author={Mark Chen and Jerry Tworek and Heewoo Jun and Qiming Yuan and Henrique Ponde de Oliveira Pinto and Jared Kaplan and Harri Edwards and Yuri Burda and Nicholas Joseph and Greg Brockman and Alex Ray and Raul Puri and Gretchen Krueger and Michael Petrov and Heidy Khlaaf and Girish Sastry and Pamela Mishkin and Brooke Chan and Scott Gray and Nick Ryder and Mikhail Pavlov and Alethea Power and Lukasz Kaiser and Mohammad Bavarian and Clemens Winter and Philippe Tillet and Felipe Petroski Such and Dave Cummings and Matthias Plappert and Fotios Chantzis and Elizabeth Barnes and Ariel Herbert-Voss and William Hebgen Guss and Alex Nichol and Alex Paino and Nikolas Tezak and Jie Tang and Igor Babuschkin and Suchir Balaji and Shantanu Jain and William Saunders and Christopher Hesse and Andrew N. Carr and Jan Leike and Josh Achiam and Vedant Misra and Evan Morikawa and Alec Radford and Matthew Knight and Miles Brundage and Mira Murati and Katie Mayer and Peter Welinder and Bob McGrew and Dario Amodei and Sam McCandlish and Ilya Sutskever and Wojciech Zaremba},
year={2021},
eprint={2107.03374},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
```

### Groups and Tasks

#### Groups

* Not part of a group yet.

#### Tasks

- `humaneval`: pass@1 with greedy decoding (`do_sample: false`)
- `humaneval_64`: pass@64 variant that draws 64 samples per problem (temperature 0.2, top-p 0.95); scoring for both is described below
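
Both tasks report the unbiased pass@k estimator from the Codex paper: with `n` samples drawn per problem and `c` of them passing the unit tests,

```
pass@k = E_problems[ 1 - C(n - c, k) / C(n, k) ]
```

where `C` denotes the binomial coefficient; in this harness the estimate is computed by the Hugging Face `code_eval` metric invoked from `utils.py`.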

### Checklist

For adding novel benchmarks/datasets to the library:
* [ ] Is the task an existing benchmark in the literature?
* [ ] Have you referenced the original paper that introduced the task?
* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?


If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
30 changes: 30 additions & 0 deletions lm_eval/tasks/humaneval/humaneval.yaml
@@ -0,0 +1,30 @@
task: humaneval
dataset_path: openai/openai_humaneval
unsafe_code: true
output_type: generate_until
test_split: test
doc_to_text: "{{prompt}}"
doc_to_target: "{{test}}\ncheck({{entry_point}})"
metric_list:
- metric: !function utils.pass_at_k
aggregation: mean
higher_is_better: true
k: [1]
generation_kwargs:
until:
- "\nclass"
- "\ndef"
- "\n#"
- "\nif"
- "\nprint"
max_gen_toks: 1024
do_sample: false
repeats: 1
num_fewshot: 0
filter_list:
- name: "create_test"
filter:
- function: "custom"
filter_fn: !function utils.build_predictions
metadata:
version: 1.0
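
For context, a sketch of how one record flows through this config (field values are illustrative, not taken from the dataset): `doc_to_text` shows the model the function signature plus docstring, the `create_test` filter prepends that prompt to each completion, and `doc_to_target` builds the unit-test reference handed to `code_eval`.

```python
# Illustrative only: what the pieces above produce for a single record.
doc = {
    "prompt": 'def add(a, b):\n    """Return the sum of a and b."""\n',
    "test": "def check(candidate):\n    assert candidate(2, 3) == 5\n",
    "entry_point": "add",
}

completion = "    return a + b\n"    # raw model continuation of the prompt

prediction = doc["prompt"] + completion                          # via utils.build_predictions
reference = doc["test"] + "\ncheck(" + doc["entry_point"] + ")"  # via doc_to_target

# code_eval then executes prediction + reference in a subprocess, which is why
# the task is marked `unsafe_code: true` and gated behind --confirm_run_unsafe_code.
```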
19 changes: 19 additions & 0 deletions lm_eval/tasks/humaneval/humaneval_64.yaml
@@ -0,0 +1,19 @@
include: humaneval.yaml
task: humaneval_64
repeats: 64
metric_list:
- metric: !function utils.pass_at_k
aggregation: mean
higher_is_better: true
k: [2,8,16,32,64]
generation_kwargs:
until:
- "\nclass"
- "\ndef"
- "\n#"
- "\nif"
- "\nprint"
max_gen_toks: 1024
do_sample: true
temperature: 0.2
top_p: 0.95
27 changes: 27 additions & 0 deletions lm_eval/tasks/humaneval/utils.py
@@ -0,0 +1,27 @@
import evaluate as hf_evaluate


try:
    compute_ = hf_evaluate.load("code_eval")
    # Smoke test: confirm the code_eval metric loads and can execute code.
    test_cases = ["assert add(2, 3)==5"]
    candidates = [["def add(a,b): return a*b"]]
    results = compute_.compute(references=test_cases, predictions=candidates, k=[1])
except Exception as e:
    raise e


def pass_at_k(references: list[str], predictions: list[list[str]], k: list[int] = None):
global compute_
assert k is not None
if isinstance(k, int):
k = [k]
res = compute_.compute(
references=references,
predictions=predictions,
k=k,
)
return res[0]


def build_predictions(resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
return [[doc["prompt"] + r for r in resp] for resp, doc in zip(resps, docs)]
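
A toy usage sketch for these helpers (note: the Hugging Face `code_eval` metric refuses to execute code unless `HF_ALLOW_CODE_EVAL=1` is set in the environment):

```python
# One problem, two sampled completions (one correct, one not).
predictions = [[
    "def add(a, b):\n    return a + b",
    "def add(a, b):\n    return a - b",
]]
references = ["assert add(2, 3) == 5"]

print(pass_at_k(references=references, predictions=predictions, k=[1, 2]))
# e.g. {'pass@1': 0.5, 'pass@2': 1.0}
```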
