diff --git a/lm_eval/tasks/README.md b/lm_eval/tasks/README.md index 17e2f9b2e4..8a9363a90b 100644 --- a/lm_eval/tasks/README.md +++ b/lm_eval/tasks/README.md @@ -80,6 +80,7 @@ | medqa | Multiple choice question answering based on the United States Medical License Exams. | | | [mgsm](mgsm/README.md) | Benchmark of multilingual grade-school math problems. | Spanish, French, German, Russian, Chinese, Japanese, Thai, Swahili, Bengali, Telugu | | [minerva_math](minerva_math/README.md) | Mathematics-focused tasks requiring numerical reasoning and problem-solving skills. | English | +| [mlqa](mlqa/README.md) | MultiLingual Question Answering benchmark dataset for evaluating cross-lingual question answering performance. | English, Arabic, German, Spanish, Hindi, Vietnamese, Simplified Chinese | | [mmlu](mmlu/README.md) | Massive Multitask Language Understanding benchmark for broad domain language evaluation. Several variants are supported. | English | | [mmlu_pro](mmlu_pro/README.md) | A refined set of MMLU, integrating more challenging, reasoning-focused questions and expanding the choice set from four to ten options. | English | | [mmlusr](mmlusr/README.md) | Variation of MMLU designed to be more rigorous. | English | diff --git a/lm_eval/tasks/mlqa/README.md b/lm_eval/tasks/mlqa/README.md new file mode 100644 index 0000000000..3d82f95ff0 --- /dev/null +++ b/lm_eval/tasks/mlqa/README.md @@ -0,0 +1,101 @@ +# MLQA + +### Paper + +Title: `MLQA: Evaluating Cross-lingual Extractive Question Answering` + +Abstract: `https://arxiv.org/abs/1910.07475` + +MLQA (MultiLingual Question Answering) is a benchmark dataset for evaluating cross-lingual question answering performance. +MLQA consists of over 5K extractive QA instances (12K in English) in SQuAD format in seven languages - English, Arabic, +German, Spanish, Hindi, Vietnamese and Simplified Chinese. MLQA is highly parallel, with QA instances parallel between +4 different languages on average + +Homepage: `https://github.com/facebookresearch/MLQA` + + +### Citation + +``` +@misc{lewis2020mlqaevaluatingcrosslingualextractive, + title={MLQA: Evaluating Cross-lingual Extractive Question Answering}, + author={Patrick Lewis and Barlas Oğuz and Ruty Rinott and Sebastian Riedel and Holger Schwenk}, + year={2020}, + eprint={1910.07475}, + archivePrefix={arXiv}, + primaryClass={cs.CL}, + url={https://arxiv.org/abs/1910.07475}, +} +``` + +### Groups, Tags, and Tasks + +#### Groups + +* Not part of a group yet + +#### Tasks + +Tasks of the form `mlqa_context-lang_question-lang.yaml` +* `mlqa_ar_ar.yaml` +* `mlqa_ar_de.yaml` +* `mlqa_ar_vi.yaml` +* `mlqa_ar_zh.yaml` +* `mlqa_ar_en.yaml` +* `mlqa_ar_es.yaml` +* `mlqa_ar_hi.yaml` +* `mlqa_de_ar.yaml` +* `mlqa_de_de.yaml` +* `mlqa_de_vi.yaml` +* `mlqa_de_zh.yaml` +* `mlqa_de_en.yaml` +* `mlqa_de_es.yaml` +* `mlqa_de_hi.yaml` +* `mlqa_vi_ar.yaml` +* `mlqa_vi_de.yaml` +* `mlqa_vi_vi.yaml` +* `mlqa_vi_zh.yaml` +* `mlqa_vi_en.yaml` +* `mlqa_vi_es.yaml` +* `mlqa_vi_hi.yaml` +* `mlqa_zh_ar.yaml` +* `mlqa_zh_de.yaml` +* `mlqa_zh_vi.yaml` +* `mlqa_zh_zh.yaml` +* `mlqa_zh_en.yaml` +* `mlqa_zh_es.yaml` +* `mlqa_zh_hi.yaml` +* `mlqa_en_ar.yaml` +* `mlqa_en_de.yaml` +* `mlqa_en_vi.yaml` +* `mlqa_en_zh.yaml` +* `mlqa_en_en.yaml` +* `mlqa_en_es.yaml` +* `mlqa_en_hi.yaml` +* `mlqa_es_ar.yaml` +* `mlqa_es_de.yaml` +* `mlqa_es_vi.yaml` +* `mlqa_es_zh.yaml` +* `mlqa_es_en.yaml` +* `mlqa_es_es.yaml` +* `mlqa_es_hi.yaml` +* `mlqa_hi_ar.yaml` +* `mlqa_hi_de.yaml` +* `mlqa_hi_vi.yaml` +* `mlqa_hi_zh.yaml` +* `mlqa_hi_en.yaml` +* `mlqa_hi_es.yaml` +* `mlqa_hi_hi.yaml` + +### Checklist + +For adding novel benchmarks/datasets to the library: +* [x] Is the task an existing benchmark in the literature? + * [x] Have you referenced the original paper that introduced the task? + * [x] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test? + + +If other tasks on this dataset are already supported: +* [ ] Is the "Main" variant of this task clearly denoted? +* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates? +* [ ] Have you noted which, if any, published evaluation setups are matched by this variant? diff --git a/lm_eval/tasks/mlqa/generate_tasks.py b/lm_eval/tasks/mlqa/generate_tasks.py new file mode 100644 index 0000000000..19bd3533af --- /dev/null +++ b/lm_eval/tasks/mlqa/generate_tasks.py @@ -0,0 +1,48 @@ +# ruff: noqa: E731, E741 +""" +Script to generate task YAMLs for the mlqa dataset. +Based on `tasks/bigbench/generate_tasks.py`. +""" + +from datasets import get_dataset_config_names + + +chosen_subtasks = [] + +language_dict = { + "en": "english", + "es": "spanish", + "hi": "hindi", + "vi": "vietnamese", + "de": "german", + "ar": "arabic", + "zh": "chinese", +} + + +def main() -> None: + configs = get_dataset_config_names("facebook/mlqa", trust_remote_code=True) + for config in configs: + if len(config.split(".")) == 2: + continue + else: + chosen_subtasks.append(config) + assert len(chosen_subtasks) == 49 + for task in chosen_subtasks: + file_name = f"{task.replace('.', '_')}.yaml" + context_lang = file_name.split("_")[1] + # Not using yaml to avoid tagging issues with !function + with open(file_name, "w", encoding="utf-8") as f: + f.write("# Generated by generate_tasks.py\n") + + # Manually writing the YAML-like content inside files to avoid tagging issues + f.write("include: mlqa_common_yaml\n") + f.write(f"task: {task.replace('.', '_')}\n") + f.write(f"dataset_name: {task}\n") + f.write( + f"process_results: !function utils.process_results_{context_lang}\n" + ) + + +if __name__ == "__main__": + main() diff --git a/lm_eval/tasks/mlqa/mlqa_ar_ar.yaml b/lm_eval/tasks/mlqa/mlqa_ar_ar.yaml new file mode 100644 index 0000000000..8db625acce --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_ar_ar.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_ar_ar +dataset_name: mlqa.ar.ar +process_results: !function utils.process_results_ar diff --git a/lm_eval/tasks/mlqa/mlqa_ar_de.yaml b/lm_eval/tasks/mlqa/mlqa_ar_de.yaml new file mode 100644 index 0000000000..3d1468a7bd --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_ar_de.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_ar_de +dataset_name: mlqa.ar.de +process_results: !function utils.process_results_ar diff --git a/lm_eval/tasks/mlqa/mlqa_ar_en.yaml b/lm_eval/tasks/mlqa/mlqa_ar_en.yaml new file mode 100644 index 0000000000..18e763e8ac --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_ar_en.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_ar_en +dataset_name: mlqa.ar.en +process_results: !function utils.process_results_ar diff --git a/lm_eval/tasks/mlqa/mlqa_ar_es.yaml b/lm_eval/tasks/mlqa/mlqa_ar_es.yaml new file mode 100644 index 0000000000..c93ef03ec0 --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_ar_es.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_ar_es +dataset_name: mlqa.ar.es +process_results: !function utils.process_results_ar diff --git a/lm_eval/tasks/mlqa/mlqa_ar_hi.yaml b/lm_eval/tasks/mlqa/mlqa_ar_hi.yaml new file mode 100644 index 0000000000..5abb023ccd --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_ar_hi.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_ar_hi +dataset_name: mlqa.ar.hi +process_results: !function utils.process_results_ar diff --git a/lm_eval/tasks/mlqa/mlqa_ar_vi.yaml b/lm_eval/tasks/mlqa/mlqa_ar_vi.yaml new file mode 100644 index 0000000000..54869c657d --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_ar_vi.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_ar_vi +dataset_name: mlqa.ar.vi +process_results: !function utils.process_results_ar diff --git a/lm_eval/tasks/mlqa/mlqa_ar_zh.yaml b/lm_eval/tasks/mlqa/mlqa_ar_zh.yaml new file mode 100644 index 0000000000..5236d6cb87 --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_ar_zh.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_ar_zh +dataset_name: mlqa.ar.zh +process_results: !function utils.process_results_ar diff --git a/lm_eval/tasks/mlqa/mlqa_common_yaml b/lm_eval/tasks/mlqa/mlqa_common_yaml new file mode 100644 index 0000000000..c52ecb8914 --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_common_yaml @@ -0,0 +1,22 @@ +dataset_path: facebook/mlqa +dataset_kwargs: + trust_remote_code: true +test_split: test +validation_split: validation +output_type: generate_until +doc_to_text: "Context: {{context}}\n\nQuestion: {{question}}\n\nAnswer:" +doc_to_target: "{{answers}}" +process_docs: !function utils.process_docs +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true + - metric: f1 + aggregation: mean + higher_is_better: true +generation_kwargs: + until: + - "\n" + do_sample: false +metadata: + version: 0.0 diff --git a/lm_eval/tasks/mlqa/mlqa_de_ar.yaml b/lm_eval/tasks/mlqa/mlqa_de_ar.yaml new file mode 100644 index 0000000000..1090a58925 --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_de_ar.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_de_ar +dataset_name: mlqa.de.ar +process_results: !function utils.process_results_de diff --git a/lm_eval/tasks/mlqa/mlqa_de_de.yaml b/lm_eval/tasks/mlqa/mlqa_de_de.yaml new file mode 100644 index 0000000000..be465ab57a --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_de_de.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_de_de +dataset_name: mlqa.de.de +process_results: !function utils.process_results_de diff --git a/lm_eval/tasks/mlqa/mlqa_de_en.yaml b/lm_eval/tasks/mlqa/mlqa_de_en.yaml new file mode 100644 index 0000000000..55f2652ce4 --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_de_en.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_de_en +dataset_name: mlqa.de.en +process_results: !function utils.process_results_de diff --git a/lm_eval/tasks/mlqa/mlqa_de_es.yaml b/lm_eval/tasks/mlqa/mlqa_de_es.yaml new file mode 100644 index 0000000000..d4f085e624 --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_de_es.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_de_es +dataset_name: mlqa.de.es +process_results: !function utils.process_results_de diff --git a/lm_eval/tasks/mlqa/mlqa_de_hi.yaml b/lm_eval/tasks/mlqa/mlqa_de_hi.yaml new file mode 100644 index 0000000000..ff3bbc4286 --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_de_hi.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_de_hi +dataset_name: mlqa.de.hi +process_results: !function utils.process_results_de diff --git a/lm_eval/tasks/mlqa/mlqa_de_vi.yaml b/lm_eval/tasks/mlqa/mlqa_de_vi.yaml new file mode 100644 index 0000000000..fe61983b70 --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_de_vi.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_de_vi +dataset_name: mlqa.de.vi +process_results: !function utils.process_results_de diff --git a/lm_eval/tasks/mlqa/mlqa_de_zh.yaml b/lm_eval/tasks/mlqa/mlqa_de_zh.yaml new file mode 100644 index 0000000000..ee1855626f --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_de_zh.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_de_zh +dataset_name: mlqa.de.zh +process_results: !function utils.process_results_de diff --git a/lm_eval/tasks/mlqa/mlqa_en_ar.yaml b/lm_eval/tasks/mlqa/mlqa_en_ar.yaml new file mode 100644 index 0000000000..a8c72d2694 --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_en_ar.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_en_ar +dataset_name: mlqa.en.ar +process_results: !function utils.process_results_en diff --git a/lm_eval/tasks/mlqa/mlqa_en_de.yaml b/lm_eval/tasks/mlqa/mlqa_en_de.yaml new file mode 100644 index 0000000000..b27e02ae6c --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_en_de.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_en_de +dataset_name: mlqa.en.de +process_results: !function utils.process_results_en diff --git a/lm_eval/tasks/mlqa/mlqa_en_en.yaml b/lm_eval/tasks/mlqa/mlqa_en_en.yaml new file mode 100644 index 0000000000..d15e222f7b --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_en_en.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_en_en +dataset_name: mlqa.en.en +process_results: !function utils.process_results_en diff --git a/lm_eval/tasks/mlqa/mlqa_en_es.yaml b/lm_eval/tasks/mlqa/mlqa_en_es.yaml new file mode 100644 index 0000000000..eddb728f02 --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_en_es.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_en_es +dataset_name: mlqa.en.es +process_results: !function utils.process_results_en diff --git a/lm_eval/tasks/mlqa/mlqa_en_hi.yaml b/lm_eval/tasks/mlqa/mlqa_en_hi.yaml new file mode 100644 index 0000000000..7c2e38249a --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_en_hi.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_en_hi +dataset_name: mlqa.en.hi +process_results: !function utils.process_results_en diff --git a/lm_eval/tasks/mlqa/mlqa_en_vi.yaml b/lm_eval/tasks/mlqa/mlqa_en_vi.yaml new file mode 100644 index 0000000000..1a2f635ea3 --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_en_vi.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_en_vi +dataset_name: mlqa.en.vi +process_results: !function utils.process_results_en diff --git a/lm_eval/tasks/mlqa/mlqa_en_zh.yaml b/lm_eval/tasks/mlqa/mlqa_en_zh.yaml new file mode 100644 index 0000000000..91336eba9a --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_en_zh.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_en_zh +dataset_name: mlqa.en.zh +process_results: !function utils.process_results_en diff --git a/lm_eval/tasks/mlqa/mlqa_es_ar.yaml b/lm_eval/tasks/mlqa/mlqa_es_ar.yaml new file mode 100644 index 0000000000..9a24508cbd --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_es_ar.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_es_ar +dataset_name: mlqa.es.ar +process_results: !function utils.process_results_es diff --git a/lm_eval/tasks/mlqa/mlqa_es_de.yaml b/lm_eval/tasks/mlqa/mlqa_es_de.yaml new file mode 100644 index 0000000000..9a40b2b695 --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_es_de.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_es_de +dataset_name: mlqa.es.de +process_results: !function utils.process_results_es diff --git a/lm_eval/tasks/mlqa/mlqa_es_en.yaml b/lm_eval/tasks/mlqa/mlqa_es_en.yaml new file mode 100644 index 0000000000..660968c7fd --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_es_en.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_es_en +dataset_name: mlqa.es.en +process_results: !function utils.process_results_es diff --git a/lm_eval/tasks/mlqa/mlqa_es_es.yaml b/lm_eval/tasks/mlqa/mlqa_es_es.yaml new file mode 100644 index 0000000000..1232947b92 --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_es_es.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_es_es +dataset_name: mlqa.es.es +process_results: !function utils.process_results_es diff --git a/lm_eval/tasks/mlqa/mlqa_es_hi.yaml b/lm_eval/tasks/mlqa/mlqa_es_hi.yaml new file mode 100644 index 0000000000..5502288925 --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_es_hi.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_es_hi +dataset_name: mlqa.es.hi +process_results: !function utils.process_results_es diff --git a/lm_eval/tasks/mlqa/mlqa_es_vi.yaml b/lm_eval/tasks/mlqa/mlqa_es_vi.yaml new file mode 100644 index 0000000000..0ea9027dec --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_es_vi.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_es_vi +dataset_name: mlqa.es.vi +process_results: !function utils.process_results_es diff --git a/lm_eval/tasks/mlqa/mlqa_es_zh.yaml b/lm_eval/tasks/mlqa/mlqa_es_zh.yaml new file mode 100644 index 0000000000..caecd1b2d0 --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_es_zh.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_es_zh +dataset_name: mlqa.es.zh +process_results: !function utils.process_results_es diff --git a/lm_eval/tasks/mlqa/mlqa_hi_ar.yaml b/lm_eval/tasks/mlqa/mlqa_hi_ar.yaml new file mode 100644 index 0000000000..e4c4263a1d --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_hi_ar.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_hi_ar +dataset_name: mlqa.hi.ar +process_results: !function utils.process_results_hi diff --git a/lm_eval/tasks/mlqa/mlqa_hi_de.yaml b/lm_eval/tasks/mlqa/mlqa_hi_de.yaml new file mode 100644 index 0000000000..8069b5a07b --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_hi_de.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_hi_de +dataset_name: mlqa.hi.de +process_results: !function utils.process_results_hi diff --git a/lm_eval/tasks/mlqa/mlqa_hi_en.yaml b/lm_eval/tasks/mlqa/mlqa_hi_en.yaml new file mode 100644 index 0000000000..d7a18067bc --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_hi_en.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_hi_en +dataset_name: mlqa.hi.en +process_results: !function utils.process_results_hi diff --git a/lm_eval/tasks/mlqa/mlqa_hi_es.yaml b/lm_eval/tasks/mlqa/mlqa_hi_es.yaml new file mode 100644 index 0000000000..d152ad66dc --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_hi_es.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_hi_es +dataset_name: mlqa.hi.es +process_results: !function utils.process_results_hi diff --git a/lm_eval/tasks/mlqa/mlqa_hi_hi.yaml b/lm_eval/tasks/mlqa/mlqa_hi_hi.yaml new file mode 100644 index 0000000000..1ce79e6bbe --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_hi_hi.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_hi_hi +dataset_name: mlqa.hi.hi +process_results: !function utils.process_results_hi diff --git a/lm_eval/tasks/mlqa/mlqa_hi_vi.yaml b/lm_eval/tasks/mlqa/mlqa_hi_vi.yaml new file mode 100644 index 0000000000..534d90f70d --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_hi_vi.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_hi_vi +dataset_name: mlqa.hi.vi +process_results: !function utils.process_results_hi diff --git a/lm_eval/tasks/mlqa/mlqa_hi_zh.yaml b/lm_eval/tasks/mlqa/mlqa_hi_zh.yaml new file mode 100644 index 0000000000..8432db492d --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_hi_zh.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_hi_zh +dataset_name: mlqa.hi.zh +process_results: !function utils.process_results_hi diff --git a/lm_eval/tasks/mlqa/mlqa_vi_ar.yaml b/lm_eval/tasks/mlqa/mlqa_vi_ar.yaml new file mode 100644 index 0000000000..c22c11cd06 --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_vi_ar.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_vi_ar +dataset_name: mlqa.vi.ar +process_results: !function utils.process_results_vi diff --git a/lm_eval/tasks/mlqa/mlqa_vi_de.yaml b/lm_eval/tasks/mlqa/mlqa_vi_de.yaml new file mode 100644 index 0000000000..948ac3ac36 --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_vi_de.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_vi_de +dataset_name: mlqa.vi.de +process_results: !function utils.process_results_vi diff --git a/lm_eval/tasks/mlqa/mlqa_vi_en.yaml b/lm_eval/tasks/mlqa/mlqa_vi_en.yaml new file mode 100644 index 0000000000..0106867703 --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_vi_en.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_vi_en +dataset_name: mlqa.vi.en +process_results: !function utils.process_results_vi diff --git a/lm_eval/tasks/mlqa/mlqa_vi_es.yaml b/lm_eval/tasks/mlqa/mlqa_vi_es.yaml new file mode 100644 index 0000000000..9ac62c1056 --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_vi_es.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_vi_es +dataset_name: mlqa.vi.es +process_results: !function utils.process_results_vi diff --git a/lm_eval/tasks/mlqa/mlqa_vi_hi.yaml b/lm_eval/tasks/mlqa/mlqa_vi_hi.yaml new file mode 100644 index 0000000000..26b232a879 --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_vi_hi.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_vi_hi +dataset_name: mlqa.vi.hi +process_results: !function utils.process_results_vi diff --git a/lm_eval/tasks/mlqa/mlqa_vi_vi.yaml b/lm_eval/tasks/mlqa/mlqa_vi_vi.yaml new file mode 100644 index 0000000000..d8277d78eb --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_vi_vi.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_vi_vi +dataset_name: mlqa.vi.vi +process_results: !function utils.process_results_vi diff --git a/lm_eval/tasks/mlqa/mlqa_vi_zh.yaml b/lm_eval/tasks/mlqa/mlqa_vi_zh.yaml new file mode 100644 index 0000000000..7ecc6b9192 --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_vi_zh.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_vi_zh +dataset_name: mlqa.vi.zh +process_results: !function utils.process_results_vi diff --git a/lm_eval/tasks/mlqa/mlqa_zh_ar.yaml b/lm_eval/tasks/mlqa/mlqa_zh_ar.yaml new file mode 100644 index 0000000000..42c3713d5a --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_zh_ar.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_zh_ar +dataset_name: mlqa.zh.ar +process_results: !function utils.process_results_zh diff --git a/lm_eval/tasks/mlqa/mlqa_zh_de.yaml b/lm_eval/tasks/mlqa/mlqa_zh_de.yaml new file mode 100644 index 0000000000..cb5e4cb884 --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_zh_de.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_zh_de +dataset_name: mlqa.zh.de +process_results: !function utils.process_results_zh diff --git a/lm_eval/tasks/mlqa/mlqa_zh_en.yaml b/lm_eval/tasks/mlqa/mlqa_zh_en.yaml new file mode 100644 index 0000000000..653f26aefa --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_zh_en.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_zh_en +dataset_name: mlqa.zh.en +process_results: !function utils.process_results_zh diff --git a/lm_eval/tasks/mlqa/mlqa_zh_es.yaml b/lm_eval/tasks/mlqa/mlqa_zh_es.yaml new file mode 100644 index 0000000000..c98203f76f --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_zh_es.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_zh_es +dataset_name: mlqa.zh.es +process_results: !function utils.process_results_zh diff --git a/lm_eval/tasks/mlqa/mlqa_zh_hi.yaml b/lm_eval/tasks/mlqa/mlqa_zh_hi.yaml new file mode 100644 index 0000000000..ed58f47f4d --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_zh_hi.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_zh_hi +dataset_name: mlqa.zh.hi +process_results: !function utils.process_results_zh diff --git a/lm_eval/tasks/mlqa/mlqa_zh_vi.yaml b/lm_eval/tasks/mlqa/mlqa_zh_vi.yaml new file mode 100644 index 0000000000..7043676235 --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_zh_vi.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_zh_vi +dataset_name: mlqa.zh.vi +process_results: !function utils.process_results_zh diff --git a/lm_eval/tasks/mlqa/mlqa_zh_zh.yaml b/lm_eval/tasks/mlqa/mlqa_zh_zh.yaml new file mode 100644 index 0000000000..792b5ee0c9 --- /dev/null +++ b/lm_eval/tasks/mlqa/mlqa_zh_zh.yaml @@ -0,0 +1,5 @@ +# Generated by generate_tasks.py +include: mlqa_common_yaml +task: mlqa_zh_zh +dataset_name: mlqa.zh.zh +process_results: !function utils.process_results_zh diff --git a/lm_eval/tasks/mlqa/utils.py b/lm_eval/tasks/mlqa/utils.py new file mode 100644 index 0000000000..61e593716a --- /dev/null +++ b/lm_eval/tasks/mlqa/utils.py @@ -0,0 +1,165 @@ +""" +Code based on Official evaluation script for the MLQA dataset. +Repo: https://github.com/facebookresearch/MLQA/blob/main/mlqa_evaluation_v1.py +""" + +import re +import string +import sys +import unicodedata +from collections import Counter + +import datasets + + +PUNCT = { + chr(i) + for i in range(sys.maxunicode) + if unicodedata.category(chr(i)).startswith("P") +}.union(string.punctuation) +WHITESPACE_LANGS = ["en", "es", "hi", "vi", "de", "ar"] +MIXED_SEGMENTATION_LANGS = ["zh"] + + +def whitespace_tokenize(text): + return text.split() + + +def mixed_segmentation(text): + segs_out = [] + temp_str = "" + for char in text: + if re.search(r"[\u4e00-\u9fa5]", char) or char in PUNCT: + if temp_str != "": + ss = whitespace_tokenize(temp_str) + segs_out.extend(ss) + temp_str = "" + segs_out.append(char) + else: + temp_str += char + + if temp_str != "": + ss = whitespace_tokenize(temp_str) + segs_out.extend(ss) + + return segs_out + + +def normalize_answer(s, lang): + """Lower text and remove punctuation, articles and extra whitespace.""" + + def remove_articles(text, lang): + if lang == "en": + return re.sub(r"\b(a|an|the)\b", " ", text) + elif lang == "es": + return re.sub(r"\b(un|una|unos|unas|el|la|los|las)\b", " ", text) + elif lang == "hi": + return text # Hindi does not have formal articles + elif lang == "vi": + return re.sub(r"\b(của|là|cái|chiếc|những)\b", " ", text) + elif lang == "de": + return re.sub( + r"\b(ein|eine|einen|einem|eines|einer|der|die|das|den|dem|des)\b", + " ", + text, + ) + elif lang == "ar": + return re.sub(r"\sال^|ال", " ", text) + elif lang == "zh": + return text # Chinese does not have formal articles + else: + raise Exception("Unknown Language {}".format(lang)) + + def white_space_fix(text, lang): + if lang in WHITESPACE_LANGS: + tokens = whitespace_tokenize(text) + elif lang in MIXED_SEGMENTATION_LANGS: + tokens = mixed_segmentation(text) + else: + raise Exception("Unknown Language {}".format(lang)) + return " ".join([t for t in tokens if t.strip() != ""]) + + def remove_punc(text): + return "".join(ch for ch in text if ch not in PUNCT) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)), lang), lang) + + +def f1_score(prediction, ground_truth, lang): + prediction_tokens = normalize_answer(prediction, lang).split() + ground_truth_tokens = normalize_answer(ground_truth, lang).split() + common = Counter(prediction_tokens) & Counter(ground_truth_tokens) + num_same = sum(common.values()) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(prediction_tokens) + recall = 1.0 * num_same / len(ground_truth_tokens) + f1 = (2 * precision * recall) / (precision + recall) + return f1 + + +def exact_match_score(prediction, ground_truth, lang): + return normalize_answer(prediction, lang) == normalize_answer(ground_truth, lang) + + +def metric_max_over_ground_truths(metric_fn, prediction, ground_truths, lang): + scores_for_ground_truths = [] + for ground_truth in ground_truths: + score = metric_fn(prediction, ground_truth, lang) + scores_for_ground_truths.append(score) + return max(scores_for_ground_truths) + + +def process_docs(dataset: datasets.Dataset) -> datasets.Dataset: + def _process_doc(doc): + out_doc = { + "context": doc["context"], + "question": doc["question"], + "answers": doc["answers"]["text"], + } + return out_doc + + return dataset.map(_process_doc) + + +# Base function +def process_results_lang(doc, results, lang): + ground_truths = doc["answers"] + prediction = results[0].strip() + exact_match = metric_max_over_ground_truths( + exact_match_score, prediction, ground_truths, lang + ) + f1 = metric_max_over_ground_truths(f1_score, prediction, ground_truths, lang) + return {"exact_match": exact_match, "f1": f1} + + +# Language Wrapper functions +def process_results_en(doc, results): + return process_results_lang(doc, results, "en") + + +def process_results_es(doc, results): + return process_results_lang(doc, results, "es") + + +def process_results_hi(doc, results): + return process_results_lang(doc, results, "hi") + + +def process_results_vi(doc, results): + return process_results_lang(doc, results, "vi") + + +def process_results_de(doc, results): + return process_results_lang(doc, results, "de") + + +def process_results_ar(doc, results): + return process_results_lang(doc, results, "ar") + + +def process_results_zh(doc, results): + return process_results_lang(doc, results, "zh")