
Commit

Merge pull request #138 from alimaredia/mtbench-branch-judgement-return-overall-score

return overall_score from MTBenchBranch.generate_judgement()
danmcp authored Sep 28, 2024
2 parents fa22ef5 + b22b40b commit 40cc370
Showing 3 changed files with 10 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/e2e-nvidia-t4-x1.yml
@@ -142,7 +142,7 @@ jobs:
working-directory: ./instructlab
run: |
. venv/bin/activate
- ./scripts/basic-workflow-tests.sh -m
+ ./scripts/basic-workflow-tests.sh -msq
stop-runner:
name: Stop external EC2 runner
6 changes: 4 additions & 2 deletions src/instructlab/eval/mt_bench.py
@@ -246,10 +246,12 @@ def judge_answers(
serving_gpus Number of gpus allocated for serving. Used to tune with max_workers=auto. None indicates to use value specified in constructor.
Returns:
+ overall_score overall score from the evaluation
qa_pairs Question and answer pairs (with scores) from the evaluation
error_rate percentage of questions dropped due to errors during evaluation
"""
logger.debug(locals())
- _, qa_pairs, _, error_rate = mt_bench_judgment.generate_judgment(
+ overall_score, qa_pairs, _, error_rate = mt_bench_judgment.generate_judgment(
self.model_name,
self.judge_model_name,
server_url,
@@ -261,4 +263,4 @@ def judge_answers(
bench_name="mt_bench_branch",
merge_system_user_message=self.merge_system_user_message,
)
- return qa_pairs, error_rate
+ return overall_score, qa_pairs, error_rate
6 changes: 5 additions & 1 deletion tests/test_branch_judge_answers.py
@@ -10,7 +10,11 @@
"../taxonomy",
"main",
)
- qa_pairs, error_rate = mt_bench_branch.judge_answers("http://localhost:8000/v1")
+ overall_score, qa_pairs, error_rate = mt_bench_branch.judge_answers(
+     "http://localhost:8000/v1"
+ )
+
+ print(f"Overall Score: {overall_score}")
print(f"Error Rate: {error_rate}")
print(f"QA Pair 0:")
pprint.pprint(qa_pairs[0])
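In short, callers of the branch evaluator's judge_answers() now unpack three values instead of two, with the overall score returned first. A minimal sketch of updating an existing call site, assuming an already-constructed mt_bench_branch evaluator (as in the test above) and a judge model served at the URL shown:

# Sketch only: `mt_bench_branch` is assumed to be an MT-Bench branch
# evaluator instance, with a judge model served at the endpoint below.
#
# Old call site (two return values):
#   qa_pairs, error_rate = mt_bench_branch.judge_answers("http://localhost:8000/v1")
#
# New call site (overall_score comes first):
overall_score, qa_pairs, error_rate = mt_bench_branch.judge_answers(
    "http://localhost:8000/v1"
)
print(f"Overall Score: {overall_score}, Error Rate: {error_rate}")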
