From 219bca17987f2520ca2af7d381cdc7ada3693c17 Mon Sep 17 00:00:00 2001 From: Ali Maredia Date: Wed, 25 Sep 2024 23:52:29 -0400 Subject: [PATCH 1/2] return overall_score from MTBenchBranch.generate_judgement() This allows the overall_score to be shown by callers of the library along with qa pairs and the error rate. This commit changes what a function in the library returns and thus is not backwards compatible. Signed-off-by: Ali Maredia --- src/instructlab/eval/mt_bench.py | 6 ++++-- tests/test_branch_judge_answers.py | 6 +++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/instructlab/eval/mt_bench.py b/src/instructlab/eval/mt_bench.py index da0d60f..22d8c70 100644 --- a/src/instructlab/eval/mt_bench.py +++ b/src/instructlab/eval/mt_bench.py @@ -260,10 +260,12 @@ def judge_answers( serving_gpus Number of gpus allocated for serving. Used to tune with max_workers=auto. None indicates to use value specified in constructor. Returns: + overall_score overall score from the evaluation qa_pairs Question and answer pairs (with scores) from the evaluation + error_rate percentage of questions dropped due to errors during evaluation """ logger.debug(locals()) - _, qa_pairs, _, error_rate = mt_bench_judgment.generate_judgment( + overall_score, qa_pairs, _, error_rate = mt_bench_judgment.generate_judgment( self.model_name, self.judge_model_name, server_url, @@ -275,4 +277,4 @@ def judge_answers( bench_name="mt_bench_branch", merge_system_user_message=self.merge_system_user_message, ) - return qa_pairs, error_rate + return overall_score, qa_pairs, error_rate diff --git a/tests/test_branch_judge_answers.py b/tests/test_branch_judge_answers.py index 5b2e566..51705f7 100755 --- a/tests/test_branch_judge_answers.py +++ b/tests/test_branch_judge_answers.py @@ -10,7 +10,11 @@ "../taxonomy", "main", ) -qa_pairs, error_rate = mt_bench_branch.judge_answers("http://localhost:8000/v1") +overall_score, qa_pairs, error_rate = mt_bench_branch.judge_answers( + "http://localhost:8000/v1" +) + +print(f"Overall Score: {overall_score}") print(f"Error Rate: {error_rate}") print(f"QA Pair 0:") pprint.pprint(qa_pairs[0]) From b22b40bed2a489e53862d22a46602723018ce8c8 Mon Sep 17 00:00:00 2001 From: Ali Maredia Date: Fri, 27 Sep 2024 15:15:04 -0400 Subject: [PATCH 2/2] update flag for basic workflow tests Signed-off-by: Ali Maredia --- .github/workflows/e2e-nvidia-t4-x1.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/e2e-nvidia-t4-x1.yml b/.github/workflows/e2e-nvidia-t4-x1.yml index d7ea0c3..755b664 100644 --- a/.github/workflows/e2e-nvidia-t4-x1.yml +++ b/.github/workflows/e2e-nvidia-t4-x1.yml @@ -142,7 +142,7 @@ jobs: working-directory: ./instructlab run: | . venv/bin/activate - ./scripts/basic-workflow-tests.sh -m + ./scripts/basic-workflow-tests.sh -msq stop-runner: name: Stop external EC2 runner