-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun.py
86 lines (68 loc) · 2.13 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import json
import logging
import time
from collections import Counter
from dataclasses import dataclass
from pathlib import Path

from senters import Bunkai, Ginza, Hasami, Kuzukiri, Pysbd, Rhoknp, Sengiri, Senter
# Module-level logger, plus the directory that holds the benchmark data.
# Paths are resolved relative to this script, not the current working directory.
logger = logging.getLogger(__name__)
here = Path(__file__).parent
data_dir = here / "data"
def post_process(sents: list[str]) -> list[str]:
    """Normalize a list of sentences for comparison.

    Args:
        sents: Raw sentences, e.g. the output of a segmenter or a gold split.

    Returns:
        Each sentence with surrounding whitespace stripped, in the original
        order, with sentences that become empty after stripping removed.
    """
    return [stripped for sent in sents if (stripped := sent.strip())]
@dataclass
class Example:
    """A single benchmark case.

    Attributes are filled from one JSON object per line of a data file.
    """

    # Raw text to be split into sentences.
    input: str
    # Gold sentence split; normalized with post_process after construction.
    output: list[str]

    def __post_init__(self) -> None:
        # Apply the same normalization that benchmark() applies to predictions,
        # so gold and predicted sentences are compared on equal footing.
        normalized = post_process(self.output)
        self.output = normalized
def benchmark(senter: Senter, examples: list[Example]) -> None:
    """Perform benchmarking and print the result.

    Runs ``senter`` on the input of every example and normalizes each
    prediction with :func:`post_process`. It then prints the micro-averaged
    sentence-level F1 score (in percent) and the time spent segmenting.

    A predicted sentence counts as correct when it exactly matches a gold
    sentence of the same example. Each gold sentence can be matched at most
    once, so repeated sentences are counted as a multiset.

    Args:
        senter: A sentence segmentation tool to be tested.
        examples: Examples used for benchmarking.
    """
    # perf_counter is the clock intended for measuring elapsed intervals.
    start = time.perf_counter()
    predictions = [post_process(senter(example.input)) for example in examples]
    elapsed_time = time.perf_counter() - start

    f1 = _micro_f1(predictions, [example.output for example in examples])
    print(f"{senter.name}\t{f1:5.1f} (Elapsed time: {elapsed_time:.2f})")


def _micro_f1(predictions: list[list[str]], golds: list[list[str]]) -> float:
    """Return the micro-averaged F1 score (0-100) of predicted vs. gold sentences.

    Sentences are matched per example as multisets. A sentence that appears
    k times in the gold split can therefore contribute at most k true
    positives. Precision (recall) is taken as 0 when there are no predicted
    (gold) sentences at all, instead of raising ZeroDivisionError.
    """
    tp = fp = fn = 0
    for prediction, gold in zip(predictions, golds):
        # Counter '&' keeps min(count_pred, count_gold) per sentence.
        matched = sum((Counter(prediction) & Counter(gold)).values())
        tp += matched
        fp += len(prediction) - matched
        fn += len(gold) - matched
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0:
        return 0.0
    return 100 * 2 * precision * recall / (precision + recall)
def main() -> None:
    """Benchmark every segmenter on every ``*.jsonl`` file in ``data_dir``.

    Each non-blank line of a data file is a JSON object whose keys are passed
    to :class:`Example`. For each file, the following are printed:

    - a ``#`` header with the file's absolute path;
    - one result line per segmenter (from :func:`benchmark`);
    - a ``---`` separator.
    """
    senter_list = [
        Rhoknp(),
        Sengiri(),
        Hasami(),
        Kuzukiri(),
        Pysbd(),
        Bunkai(),
        Ginza(),
    ]
    # glob() yields files in arbitrary order; sort for a reproducible report.
    for data_file in sorted(data_dir.glob("*.jsonl")):
        print("#", data_file.absolute())
        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with data_file.open(encoding="utf-8") as f:
            # Skip blank lines (e.g. an extra empty line at the end of the file),
            # which json.loads would otherwise reject.
            examples = [Example(**json.loads(line)) for line in f if line.strip()]
        for senter in senter_list:
            benchmark(senter, examples)
        print("---")
if __name__ == "__main__":
    # Configure root logging at INFO level with timestamps, then run the benchmark.
    log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)
    main()