-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathevaluate_retrieval_chrf.py
39 lines (33 loc) · 1.18 KB
/
evaluate_retrieval_chrf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
"""Evaluate sentence-retrieval accuracy using chrF as a similarity score.

For each source sentence, every target sentence is scored with chrF
(treating the source as the "prediction" and the candidate target as the
"reference"); retrieval is counted as correct when the highest-scoring
target has the same index as the source sentence. Mismatches are printed
for inspection, followed by the overall accuracy per language pair.
"""
import evaluate
from tqdm import tqdm
from evaluation.retrieval.retrieval_datasets import NTREXDataset

chrf = evaluate.load("chrf")
dataset = NTREXDataset()
# (source language, target language) pairs to evaluate; NTREX sentences are
# parallel, so index i in the source list aligns with index i in the target list.
lang_pairs = [
    ("deu", "gsw-BE"),
    ("deu", "gsw-ZH"),
]
for src_lang, tgt_lang in lang_pairs:
    print(f"Language pair: {src_lang} -> {tgt_lang}")
    src_sentences = dataset.get_sentences(src_lang)
    tgt_sentences = dataset.get_sentences(tgt_lang)
    num_correct = 0
    num_total = 0
    for i, src_sentence in enumerate(tqdm(src_sentences)):
        max_score = 0
        best_j = None
        # Exhaustive O(n^2) scoring: compare this source against every target.
        for j, tgt_sentence in enumerate(tgt_sentences):
            result = chrf.compute(predictions=[src_sentence], references=[[tgt_sentence]])
            score = result["score"]
            if score > max_score:
                max_score = score
                best_j = j
        is_correct = best_j == i
        num_correct += is_correct  # bool counts as 0/1
        num_total += 1
        if not is_correct:
            print(f"Source: {src_sentence} (i={i})")
            print(f"Target: {tgt_sentences[i]} (i={i})")
            # Bug fix: best_j stays None when every candidate scores 0 chrF;
            # indexing with None previously raised TypeError here.
            if best_j is None:
                print("Prediction: <none> (all chrF scores were 0)")
            else:
                print(f"Prediction: {tgt_sentences[best_j]} (j={best_j})")
            print()
    # Bug fix: guard against ZeroDivisionError when the source set is empty.
    print(f"Accuracy: {num_correct / num_total if num_total else 0.0}")