-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for generative answering of multiple_choice tasks #2601
Open
pasky
wants to merge
4
commits into
EleutherAI:main
Choose a base branch
from
pasky:multiple-choice-generate
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
c225602
fix(zeno): Generate unique ids in case of multiple filters
pasky 0bd64c2
fix(zeno): Report even non-aggregable metrics, just not as metrics
pasky 5cca68f
Add a basic support for --multiple-choice-generate
pasky d9e49af
Add support for --multiple_choice_generate abcd
pasky File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,7 +27,7 @@ | |
from lm_eval import utils | ||
from lm_eval.api import samplers | ||
from lm_eval.api.instance import Instance, OutputType | ||
from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity | ||
from lm_eval.api.metrics import bits_per_byte, exact_match_fn, mean, weighted_perplexity | ||
from lm_eval.api.registry import ( | ||
AGGREGATION_REGISTRY, | ||
DEFAULT_METRIC_REGISTRY, | ||
|
@@ -80,6 +80,8 @@ class TaskConfig(dict): | |
use_prompt: Optional[str] = None | ||
description: str = "" | ||
target_delimiter: str = " " | ||
choice_delimiter: str = " / " | ||
option_delimiter: str = "\n" | ||
fewshot_delimiter: str = "\n\n" | ||
fewshot_config: Optional[dict] = None | ||
# runtime configuration options | ||
|
@@ -111,16 +113,15 @@ def __post_init__(self) -> None: | |
if "until" not in self.generation_kwargs: | ||
self.generation_kwargs["until"] = [self.fewshot_delimiter] | ||
else: | ||
if self.output_type == "generate_until": | ||
# ensure that we greedily generate in absence of explicit arguments otherwise | ||
self.generation_kwargs = { | ||
"until": ( | ||
None | ||
if self.fewshot_delimiter is None | ||
else [self.fewshot_delimiter] | ||
), | ||
"do_sample": False, | ||
} | ||
# ensure that we greedily generate in absence of explicit arguments otherwise | ||
self.generation_kwargs = { | ||
"until": ( | ||
None | ||
if self.fewshot_delimiter is None | ||
else [self.fewshot_delimiter] | ||
), | ||
"do_sample": False, | ||
} | ||
|
||
def __getitem__(self, item): | ||
return getattr(self, item) | ||
|
@@ -380,6 +381,7 @@ def build_all_requests( | |
system_instruction: Optional[str] = None, | ||
apply_chat_template: bool = False, | ||
fewshot_as_multiturn: bool = False, | ||
multiple_choice_generate: Union[bool, str] = False, | ||
chat_template: Optional[Callable] = None, | ||
tokenizer_name: str = "", | ||
) -> None: | ||
|
@@ -391,6 +393,7 @@ def build_all_requests( | |
cache_key = f"requests-{self._config.task}-{self.config.num_fewshot}shot-rank{rank}-world_size{world_size}" | ||
cache_key += "-chat_template" if apply_chat_template else "" | ||
cache_key += "-fewshot_as_multiturn" if fewshot_as_multiturn else "" | ||
cache_key += "-multiple_choice_generate" if multiple_choice_generate else "" | ||
cache_key += ( | ||
f"-system_prompt_hash{utils.hash_string(system_instruction)}" | ||
if system_instruction is not None | ||
|
@@ -435,12 +438,22 @@ def build_all_requests( | |
total=num_docs, | ||
): | ||
# sample fewshot context #TODO: need to offset doc_id by rank now! | ||
doc_system_instruction = system_instruction or "" | ||
if self.OUTPUT_TYPE == "multiple_choice" and multiple_choice_generate: | ||
if doc_system_instruction: | ||
doc_system_instruction += " " | ||
if multiple_choice_generate == "abcd": | ||
doc_system_instruction += "Please include \"ANSWER: <letter>\" in your response with the letter of the correct last answer." | ||
else: | ||
doc_system_instruction += "Please answer with the letter of the correct last answer." | ||
Comment on lines
+446
to
+448
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What about non-english tasks that are already inside this repo? |
||
|
||
fewshot_ctx = self.fewshot_context( | ||
doc, | ||
0 if self.config.num_fewshot is None else self.config.num_fewshot, | ||
system_instruction, | ||
doc_system_instruction, | ||
apply_chat_template, | ||
fewshot_as_multiturn, | ||
multiple_choice_generate, | ||
chat_template, | ||
) | ||
|
||
|
@@ -450,6 +463,7 @@ def build_all_requests( | |
ctx=fewshot_ctx, | ||
metadata=(self.config["task"], doc_id, self.config.repeats), | ||
apply_chat_template=apply_chat_template, | ||
multiple_choice_generate=multiple_choice_generate, | ||
) | ||
|
||
if not isinstance(inst, list): | ||
|
@@ -1024,6 +1038,7 @@ def fewshot_context( | |
system_instruction: Optional[str] = None, | ||
apply_chat_template: bool = False, | ||
fewshot_as_multiturn: bool = False, | ||
multiple_choice_generate: Union[bool, str] = False, | ||
chat_template: Optional[Callable] = None, | ||
) -> str: | ||
"""Returns a fewshot context string that is made up of a prepended description | ||
|
@@ -1039,6 +1054,8 @@ def fewshot_context( | |
Whether to apply the chat template to the fewshot context. | ||
:param fewshot_as_multiturn: bool | ||
Whether to provide the fewshot examples as a multiturn conversation or a single user turn. | ||
:param multiple_choice_generate: Union[bool, str] | ||
Whether to generate multiple choice answer from scratch rather than pick by logprobs. | ||
:param chat_template: | ||
callable (from lm.apply_chat_template) that takes in a list[Dict] chat transcript and renders it into a string. | ||
:returns: str | ||
|
@@ -1085,6 +1102,17 @@ def fewshot_context( | |
labeled_examples += self.sampler.get_context(doc, num_fewshot) | ||
|
||
example = self.doc_to_text(doc) | ||
if self.config.doc_to_choice is not None and multiple_choice_generate: | ||
if not isinstance(example, str): | ||
raise NotImplementedError("--multiple_choice_generate is implemented only for simple text docs") | ||
if multiple_choice_generate == "abcd": | ||
choices = self.doc_to_choice(doc) | ||
for label, choice in zip(list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")[:len(choices)], choices): | ||
example += f"{self.config.option_delimiter}({label}) {choice}" | ||
else: | ||
example += self.config.target_delimiter | ||
example += "(" + self.config.choice_delimiter.join(self.doc_to_choice(doc)) + ")" | ||
|
||
if apply_chat_template: | ||
if self.multiple_input: | ||
return chat_template(labeled_examples) | ||
|
@@ -1300,17 +1328,24 @@ def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list]: | |
return None | ||
|
||
def construct_requests( | ||
self, doc: dict, ctx: str, **kwargs | ||
self, doc: dict, ctx: str, multiple_choice_generate: Union[bool, str], **kwargs | ||
) -> Union[List[Instance], Instance]: | ||
apply_chat_template = kwargs.pop("apply_chat_template", False) | ||
|
||
aux_arguments = None | ||
|
||
if self.OUTPUT_TYPE == "loglikelihood": | ||
self.multiple_choice_generate = multiple_choice_generate | ||
output_type = self.OUTPUT_TYPE | ||
if output_type == "multiple_choice" and multiple_choice_generate: | ||
output_type = "generate_until" | ||
if self.multiple_input: | ||
raise NotImplementedError("The \"multiple input\" mode of multiple_choice tasks is not implemented for --multiple_choice_generate.") | ||
|
||
if output_type == "loglikelihood": | ||
arguments = (ctx, self.doc_to_target(doc)) | ||
elif self.OUTPUT_TYPE == "loglikelihood_rolling": | ||
elif output_type == "loglikelihood_rolling": | ||
arguments = (self.doc_to_target(doc),) | ||
elif self.OUTPUT_TYPE == "multiple_choice": | ||
elif output_type == "multiple_choice": | ||
choices = self.doc_to_choice(doc) | ||
target_delimiter = self.config.target_delimiter | ||
if apply_chat_template: | ||
|
@@ -1337,7 +1372,7 @@ def construct_requests( | |
|
||
arguments.extend(aux_arguments) | ||
|
||
elif self.OUTPUT_TYPE == "generate_until": | ||
elif output_type == "generate_until": | ||
arguments = (ctx, deepcopy(self.config.generation_kwargs)) | ||
|
||
multimodal_arg = {} | ||
|
@@ -1355,7 +1390,7 @@ def construct_requests( | |
else: | ||
arguments = arguments + (multimodal_arg,) | ||
|
||
if self.OUTPUT_TYPE == "multiple_choice": | ||
if output_type == "multiple_choice": | ||
request_list = [ | ||
Instance( | ||
request_type="loglikelihood", | ||
|
@@ -1370,7 +1405,7 @@ def construct_requests( | |
return request_list | ||
|
||
return Instance( | ||
request_type=self.OUTPUT_TYPE, | ||
request_type=output_type, | ||
doc=doc, | ||
arguments=arguments, | ||
idx=0, | ||
|
@@ -1411,7 +1446,7 @@ def process_results(self, doc, results): | |
else {} | ||
), | ||
} | ||
elif self.OUTPUT_TYPE == "multiple_choice": | ||
elif self.OUTPUT_TYPE == "multiple_choice" and not self.multiple_choice_generate: | ||
lls, is_greedy = zip(*results) | ||
|
||
# retrieve choices in List[str] form, to compute choice lengths, etc. | ||
|
@@ -1492,14 +1527,22 @@ def process_results(self, doc, results): | |
acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0 | ||
result_dict["acc_mutual_info"] = acc_mutual_info | ||
|
||
elif self.OUTPUT_TYPE == "generate_until": | ||
elif self.OUTPUT_TYPE == "generate_until" or (self.OUTPUT_TYPE == "multiple_choice" and self.multiple_choice_generate): | ||
gold = self.doc_to_target(doc) | ||
result = results[0] | ||
if self.config.doc_to_choice is not None: | ||
# If you set doc_to_choice, | ||
# it assumes that doc_to_target returns a number. | ||
choices = self.doc_to_choice(doc) | ||
gold = choices[gold] | ||
if self.multiple_choice_generate == "abcd": | ||
try: | ||
result_label = re.findall(r"ANSWER: ([A-Z])", result)[-1] | ||
result_i = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ").index(result_label) | ||
result = choices[result_i] | ||
except (AttributeError, ValueError, IndexError): | ||
eval_logger.warning(f"[{self}] LLM did not pick a valid result ('{result}')") | ||
result = choices[0] # XXX guess "randomly" | ||
# we expect multiple_targets to be a list. | ||
elif self.multiple_target: | ||
gold = list(gold) | ||
|
@@ -1511,6 +1554,12 @@ def process_results(self, doc, results): | |
gold = type(result)(gold) | ||
|
||
for metric in self._metric_fn_list.keys(): | ||
metric_fn = self._metric_fn_list[metric] | ||
metric_result_key = metric | ||
if self.OUTPUT_TYPE == "multiple_choice" and self.multiple_choice_generate: | ||
metric_fn = exact_match_fn | ||
metric_result_key = "exact_match" | ||
|
||
if self.multiple_target: | ||
# in the case where we have multiple targets, | ||
# return true if any are true | ||
|
@@ -1522,7 +1571,7 @@ def process_results(self, doc, results): | |
gold = [gold] | ||
if metric == "exact_match": | ||
result = [result for _ in range(len(gold))] | ||
scores = self._metric_fn_list[metric]( | ||
scores = metric_fn( | ||
references=gold, | ||
predictions=result, | ||
**self._metric_fn_kwargs[metric], | ||
|
@@ -1531,15 +1580,15 @@ def process_results(self, doc, results): | |
else: | ||
for gold_option in gold: | ||
try: | ||
result_score = self._metric_fn_list[metric]( | ||
result_score = metric_fn( | ||
references=[gold_option], | ||
predictions=[result], | ||
**self._metric_fn_kwargs[metric], | ||
) | ||
except ( | ||
TypeError | ||
): # TODO: this is hacky and I don't want to do it | ||
result_score = self._metric_fn_list[metric]( | ||
result_score = metric_fn( | ||
[gold_option, result] | ||
) | ||
if isinstance(result_score, dict): | ||
|
@@ -1552,16 +1601,16 @@ def process_results(self, doc, results): | |
result_score = 0.0 | ||
else: | ||
try: | ||
result_score = self._metric_fn_list[metric]( | ||
result_score = metric_fn( | ||
references=[gold], | ||
predictions=[result], | ||
**self._metric_fn_kwargs[metric], | ||
) | ||
except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics | ||
result_score = self._metric_fn_list[metric]([gold, result]) | ||
result_score = metric_fn([gold, result]) | ||
if isinstance(result_score, dict): | ||
# TODO: this handles the case where HF evaluate returns a dict. | ||
result_score = result_score[metric] | ||
result_score = result_score[metric_result_key] | ||
result_dict[metric] = result_score | ||
else: | ||
raise ValueError( | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
May I suggest to not hardcode these. What if doc_system_instruction supposed to be delimited with some other delimiter? What if set of choices is not 4 letters, not these 4 letters, or not letters at all? This framework supports external tasks and also have multiple forks already, so there may be (I am not using "are" because of no intention to google proof of this idea) multiple choice tasks set up differently than "abcd".